1mod access_token;
11mod refresh_token;
12mod session;
13mod sso_login;
14
15pub use self::{
16 access_token::PgCompatAccessTokenRepository, refresh_token::PgCompatRefreshTokenRepository,
17 session::PgCompatSessionRepository, sso_login::PgCompatSsoLoginRepository,
18};
19
20#[cfg(test)]
21mod tests {
22 use chrono::Duration;
23 use mas_data_model::{Clock, Device, clock::MockClock};
24 use mas_storage::{
25 Pagination, RepositoryAccess,
26 compat::{
27 CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionFilter,
28 CompatSessionRepository, CompatSsoLoginFilter,
29 },
30 user::UserRepository,
31 };
32 use rand::SeedableRng;
33 use rand_chacha::ChaChaRng;
34 use sqlx::PgPool;
35 use ulid::Ulid;
36
37 use crate::PgRepository;
38
39 #[sqlx::test(migrator = "crate::MIGRATOR")]
40 async fn test_session_repository(pool: PgPool) {
41 let mut rng = ChaChaRng::seed_from_u64(42);
42 let clock = MockClock::default();
43 let mut repo = PgRepository::from_pool(&pool).await.unwrap();
44
45 let user = repo
47 .user()
48 .add(&mut rng, &clock, "john".to_owned())
49 .await
50 .unwrap();
51
52 let all = CompatSessionFilter::new().for_user(&user);
53 let active = all.active_only();
54 let finished = all.finished_only();
55 let pagination = Pagination::first(10);
56
57 assert_eq!(repo.compat_session().count(all).await.unwrap(), 0);
58 assert_eq!(repo.compat_session().count(active).await.unwrap(), 0);
59 assert_eq!(repo.compat_session().count(finished).await.unwrap(), 0);
60
61 let full_list = repo.compat_session().list(all, pagination).await.unwrap();
62 assert!(full_list.edges.is_empty());
63 let active_list = repo
64 .compat_session()
65 .list(active, pagination)
66 .await
67 .unwrap();
68 assert!(active_list.edges.is_empty());
69 let finished_list = repo
70 .compat_session()
71 .list(finished, pagination)
72 .await
73 .unwrap();
74 assert!(finished_list.edges.is_empty());
75
76 let device = Device::generate(&mut rng);
78 let device_str = device.as_str().to_owned();
79 let session = repo
80 .compat_session()
81 .add(&mut rng, &clock, &user, device.clone(), None, false, None)
82 .await
83 .unwrap();
84 assert_eq!(session.user_id, user.id);
85 assert_eq!(session.device.as_ref().unwrap().as_str(), device_str);
86 assert!(session.is_valid());
87 assert!(!session.is_finished());
88
89 assert_eq!(repo.compat_session().count(all).await.unwrap(), 1);
90 assert_eq!(repo.compat_session().count(active).await.unwrap(), 1);
91 assert_eq!(repo.compat_session().count(finished).await.unwrap(), 0);
92
93 let full_list = repo.compat_session().list(all, pagination).await.unwrap();
94 assert_eq!(full_list.edges.len(), 1);
95 assert_eq!(full_list.edges[0].node.0.id, session.id);
96 let active_list = repo
97 .compat_session()
98 .list(active, pagination)
99 .await
100 .unwrap();
101 assert_eq!(active_list.edges.len(), 1);
102 assert_eq!(active_list.edges[0].node.0.id, session.id);
103 let finished_list = repo
104 .compat_session()
105 .list(finished, pagination)
106 .await
107 .unwrap();
108 assert!(finished_list.edges.is_empty());
109
110 let session_lookup = repo
112 .compat_session()
113 .lookup(session.id)
114 .await
115 .unwrap()
116 .expect("compat session not found");
117 assert_eq!(session_lookup.id, session.id);
118 assert_eq!(session_lookup.user_id, user.id);
119 assert_eq!(session.device.as_ref().unwrap().as_str(), device_str);
120 assert!(session_lookup.is_valid());
121 assert!(!session_lookup.is_finished());
122
123 assert!(session_lookup.user_agent.is_none());
125 let session = repo
126 .compat_session()
127 .record_user_agent(session_lookup, "Mozilla/5.0".to_owned())
128 .await
129 .unwrap();
130 assert_eq!(session.user_agent.as_deref(), Some("Mozilla/5.0"));
131
132 let session_lookup = repo
134 .compat_session()
135 .lookup(session.id)
136 .await
137 .unwrap()
138 .expect("compat session not found");
139 assert_eq!(session_lookup.user_agent.as_deref(), Some("Mozilla/5.0"));
140
141 let list = repo
143 .compat_session()
144 .list(
145 CompatSessionFilter::new()
146 .for_user(&user)
147 .for_device(&device),
148 pagination,
149 )
150 .await
151 .unwrap();
152 assert_eq!(list.edges.len(), 1);
153 let session_lookup = &list.edges[0].node.0;
154 assert_eq!(session_lookup.id, session.id);
155 assert_eq!(session_lookup.user_id, user.id);
156 assert_eq!(session.device.as_ref().unwrap().as_str(), device_str);
157 assert!(session_lookup.is_valid());
158 assert!(!session_lookup.is_finished());
159
160 let session = repo.compat_session().finish(&clock, session).await.unwrap();
162 assert!(!session.is_valid());
163 assert!(session.is_finished());
164
165 assert_eq!(repo.compat_session().count(all).await.unwrap(), 1);
166 assert_eq!(repo.compat_session().count(active).await.unwrap(), 0);
167 assert_eq!(repo.compat_session().count(finished).await.unwrap(), 1);
168
169 let full_list = repo.compat_session().list(all, pagination).await.unwrap();
170 assert_eq!(full_list.edges.len(), 1);
171 assert_eq!(full_list.edges[0].node.0.id, session.id);
172 let active_list = repo
173 .compat_session()
174 .list(active, pagination)
175 .await
176 .unwrap();
177 assert!(active_list.edges.is_empty());
178 let finished_list = repo
179 .compat_session()
180 .list(finished, pagination)
181 .await
182 .unwrap();
183 assert_eq!(finished_list.edges.len(), 1);
184 assert_eq!(finished_list.edges[0].node.0.id, session.id);
185
186 let session_lookup = repo
188 .compat_session()
189 .lookup(session.id)
190 .await
191 .unwrap()
192 .expect("compat session not found");
193 assert!(!session_lookup.is_valid());
194 assert!(session_lookup.is_finished());
195
196 let unknown_session = session;
198 let login = repo
200 .compat_sso_login()
201 .add(
202 &mut rng,
203 &clock,
204 "login-token".to_owned(),
205 "https://example.com/callback".parse().unwrap(),
206 )
207 .await
208 .unwrap();
209 assert!(login.is_pending());
210
211 let browser_session = repo
213 .browser_session()
214 .add(&mut rng, &clock, &user, None)
215 .await
216 .unwrap();
217
218 let device = Device::generate(&mut rng);
220 let sso_login_session = repo
221 .compat_session()
222 .add(
223 &mut rng,
224 &clock,
225 &user,
226 device,
227 Some(&browser_session),
228 false,
229 None,
230 )
231 .await
232 .unwrap();
233
234 let login = repo
236 .compat_sso_login()
237 .fulfill(&clock, login, &browser_session)
238 .await
239 .unwrap();
240 assert!(login.is_fulfilled());
241 let login = repo
242 .compat_sso_login()
243 .exchange(&clock, login, &sso_login_session)
244 .await
245 .unwrap();
246 assert!(login.is_exchanged());
247
248 let all = CompatSessionFilter::new().for_user(&user);
251 let sso_login = all.sso_login_only();
252 let unknown = all.unknown_only();
253 assert_eq!(repo.compat_session().count(all).await.unwrap(), 2);
254 assert_eq!(repo.compat_session().count(sso_login).await.unwrap(), 1);
255 assert_eq!(repo.compat_session().count(unknown).await.unwrap(), 1);
256
257 let list = repo
258 .compat_session()
259 .list(sso_login, pagination)
260 .await
261 .unwrap();
262 assert_eq!(list.edges.len(), 1);
263 assert_eq!(list.edges[0].node.0.id, sso_login_session.id);
264 let list = repo
265 .compat_session()
266 .list(unknown, pagination)
267 .await
268 .unwrap();
269 assert_eq!(list.edges.len(), 1);
270 assert_eq!(list.edges[0].node.0.id, unknown_session.id);
271
272 assert_eq!(
276 repo.compat_session()
277 .count(all.sso_login_only().active_only())
278 .await
279 .unwrap(),
280 1
281 );
282 assert_eq!(
283 repo.compat_session()
284 .count(all.sso_login_only().finished_only())
285 .await
286 .unwrap(),
287 0
288 );
289 assert_eq!(
290 repo.compat_session()
291 .count(all.unknown_only().active_only())
292 .await
293 .unwrap(),
294 0
295 );
296 assert_eq!(
297 repo.compat_session()
298 .count(all.unknown_only().finished_only())
299 .await
300 .unwrap(),
301 1
302 );
303
304 let affected = repo
306 .compat_session()
307 .finish_bulk(&clock, all.sso_login_only().active_only())
308 .await
309 .unwrap();
310 assert_eq!(affected, 1);
311 assert_eq!(repo.compat_session().count(finished).await.unwrap(), 2);
312 assert_eq!(repo.compat_session().count(active).await.unwrap(), 0);
313 }
314
315 #[sqlx::test(migrator = "crate::MIGRATOR")]
316 async fn test_access_token_repository(pool: PgPool) {
317 const FIRST_TOKEN: &str = "first_access_token";
318 const SECOND_TOKEN: &str = "second_access_token";
319 let mut rng = ChaChaRng::seed_from_u64(42);
320 let clock = MockClock::default();
321 let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();
322
323 let user = repo
325 .user()
326 .add(&mut rng, &clock, "john".to_owned())
327 .await
328 .unwrap();
329
330 let device = Device::generate(&mut rng);
332 let session = repo
333 .compat_session()
334 .add(&mut rng, &clock, &user, device, None, false, None)
335 .await
336 .unwrap();
337
338 let token = repo
340 .compat_access_token()
341 .add(
342 &mut rng,
343 &clock,
344 &session,
345 FIRST_TOKEN.to_owned(),
346 Some(Duration::try_minutes(1).unwrap()),
347 )
348 .await
349 .unwrap();
350 assert_eq!(token.session_id, session.id);
351 assert_eq!(token.token, FIRST_TOKEN);
352
353 repo.save().await.unwrap();
355
356 {
357 let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();
358 assert!(
360 repo.compat_access_token()
361 .add(
362 &mut rng,
363 &clock,
364 &session,
365 FIRST_TOKEN.to_owned(),
366 Some(Duration::try_minutes(1).unwrap()),
367 )
368 .await
369 .is_err()
370 );
371 repo.cancel().await.unwrap();
372 }
373
374 let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();
376
377 let token_lookup = repo
379 .compat_access_token()
380 .lookup(token.id)
381 .await
382 .unwrap()
383 .expect("compat access token not found");
384 assert_eq!(token.id, token_lookup.id);
385 assert_eq!(token_lookup.session_id, session.id);
386
387 let token_lookup = repo
389 .compat_access_token()
390 .find_by_token(FIRST_TOKEN)
391 .await
392 .unwrap()
393 .expect("compat access token not found");
394 assert_eq!(token.id, token_lookup.id);
395 assert_eq!(token_lookup.session_id, session.id);
396
397 assert!(token.is_valid(clock.now()));
399
400 clock.advance(Duration::try_minutes(1).unwrap());
401 assert!(!token.is_valid(clock.now()));
403
404 let token = repo
406 .compat_access_token()
407 .add(&mut rng, &clock, &session, SECOND_TOKEN.to_owned(), None)
408 .await
409 .unwrap();
410 assert_eq!(token.session_id, session.id);
411 assert_eq!(token.token, SECOND_TOKEN);
412
413 assert!(token.is_valid(clock.now()));
415
416 repo.compat_access_token()
418 .expire(&clock, token)
419 .await
420 .unwrap();
421
422 let token = repo
424 .compat_access_token()
425 .find_by_token(SECOND_TOKEN)
426 .await
427 .unwrap()
428 .expect("compat access token not found");
429
430 assert!(!token.is_valid(clock.now()));
432
433 repo.save().await.unwrap();
434 }
435
436 #[sqlx::test(migrator = "crate::MIGRATOR")]
437 async fn test_refresh_token_repository(pool: PgPool) {
438 const ACCESS_TOKEN: &str = "access_token";
439 const REFRESH_TOKEN: &str = "refresh_token";
440 const REFRESH_TOKEN2: &str = "refresh_token2";
441 let mut rng = ChaChaRng::seed_from_u64(42);
442 let clock = MockClock::default();
443 let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();
444
445 let user = repo
447 .user()
448 .add(&mut rng, &clock, "john".to_owned())
449 .await
450 .unwrap();
451
452 let device = Device::generate(&mut rng);
454 let session = repo
455 .compat_session()
456 .add(&mut rng, &clock, &user, device, None, false, None)
457 .await
458 .unwrap();
459
460 let access_token = repo
462 .compat_access_token()
463 .add(&mut rng, &clock, &session, ACCESS_TOKEN.to_owned(), None)
464 .await
465 .unwrap();
466
467 let refresh_token = repo
468 .compat_refresh_token()
469 .add(
470 &mut rng,
471 &clock,
472 &session,
473 &access_token,
474 REFRESH_TOKEN.to_owned(),
475 )
476 .await
477 .unwrap();
478 assert_eq!(refresh_token.session_id, session.id);
479 assert_eq!(refresh_token.access_token_id, access_token.id);
480 assert_eq!(refresh_token.token, REFRESH_TOKEN);
481 assert!(refresh_token.is_valid());
482 assert!(!refresh_token.is_consumed());
483
484 let refresh_token_lookup = repo
486 .compat_refresh_token()
487 .lookup(refresh_token.id)
488 .await
489 .unwrap()
490 .expect("refresh token not found");
491 assert_eq!(refresh_token_lookup.id, refresh_token.id);
492 assert_eq!(refresh_token_lookup.session_id, session.id);
493 assert_eq!(refresh_token_lookup.access_token_id, access_token.id);
494 assert_eq!(refresh_token_lookup.token, REFRESH_TOKEN);
495 assert!(refresh_token_lookup.is_valid());
496 assert!(!refresh_token_lookup.is_consumed());
497
498 let refresh_token_lookup = repo
500 .compat_refresh_token()
501 .find_by_token(REFRESH_TOKEN)
502 .await
503 .unwrap()
504 .expect("refresh token not found");
505 assert_eq!(refresh_token_lookup.id, refresh_token.id);
506 assert_eq!(refresh_token_lookup.session_id, session.id);
507 assert_eq!(refresh_token_lookup.access_token_id, access_token.id);
508 assert_eq!(refresh_token_lookup.token, REFRESH_TOKEN);
509 assert!(refresh_token_lookup.is_valid());
510 assert!(!refresh_token_lookup.is_consumed());
511
512 let refresh_token2 = repo
514 .compat_refresh_token()
515 .add(
516 &mut rng,
517 &clock,
518 &session,
519 &access_token,
520 REFRESH_TOKEN2.to_owned(),
521 )
522 .await
523 .unwrap();
524
525 let refresh_token = repo
526 .compat_refresh_token()
527 .consume_and_replace(&clock, refresh_token, &refresh_token2)
528 .await
529 .unwrap();
530 assert!(!refresh_token.is_valid());
531 assert!(refresh_token.is_consumed());
532
533 let refresh_token_lookup = repo
535 .compat_refresh_token()
536 .find_by_token(REFRESH_TOKEN)
537 .await
538 .unwrap()
539 .expect("refresh token not found");
540 assert!(!refresh_token_lookup.is_valid());
541 assert!(refresh_token_lookup.is_consumed());
542
543 assert!(
545 repo.compat_refresh_token()
546 .consume_and_replace(&clock, refresh_token, &refresh_token2)
547 .await
548 .is_err()
549 );
550
551 repo.save().await.unwrap();
552 }
553
554 #[sqlx::test(migrator = "crate::MIGRATOR")]
555 async fn test_compat_sso_login_repository(pool: PgPool) {
556 let mut rng = ChaChaRng::seed_from_u64(42);
557 let clock = MockClock::default();
558 let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();
559
560 let user = repo
562 .user()
563 .add(&mut rng, &clock, "john".to_owned())
564 .await
565 .unwrap();
566
567 let login = repo.compat_sso_login().lookup(Ulid::nil()).await.unwrap();
569 assert_eq!(login, None);
570
571 let all = CompatSsoLoginFilter::new();
572 let for_user = all.for_user(&user);
573 let pending = all.pending_only();
574 let fulfilled = all.fulfilled_only();
575 let exchanged = all.exchanged_only();
576
577 assert_eq!(repo.compat_sso_login().count(all).await.unwrap(), 0);
579 assert_eq!(repo.compat_sso_login().count(for_user).await.unwrap(), 0);
580 assert_eq!(repo.compat_sso_login().count(pending).await.unwrap(), 0);
581 assert_eq!(repo.compat_sso_login().count(fulfilled).await.unwrap(), 0);
582 assert_eq!(repo.compat_sso_login().count(exchanged).await.unwrap(), 0);
583
584 let login = repo
586 .compat_sso_login()
587 .find_by_token("login-token")
588 .await
589 .unwrap();
590 assert_eq!(login, None);
591
592 let login = repo
594 .compat_sso_login()
595 .add(
596 &mut rng,
597 &clock,
598 "login-token".to_owned(),
599 "https://example.com/callback".parse().unwrap(),
600 )
601 .await
602 .unwrap();
603 assert!(login.is_pending());
604
605 assert_eq!(repo.compat_sso_login().count(all).await.unwrap(), 1);
607 assert_eq!(repo.compat_sso_login().count(for_user).await.unwrap(), 0);
608 assert_eq!(repo.compat_sso_login().count(pending).await.unwrap(), 1);
609 assert_eq!(repo.compat_sso_login().count(fulfilled).await.unwrap(), 0);
610 assert_eq!(repo.compat_sso_login().count(exchanged).await.unwrap(), 0);
611
612 let login_lookup = repo
614 .compat_sso_login()
615 .lookup(login.id)
616 .await
617 .unwrap()
618 .expect("login not found");
619 assert_eq!(login_lookup, login);
620
621 let login_lookup = repo
623 .compat_sso_login()
624 .find_by_token("login-token")
625 .await
626 .unwrap()
627 .expect("login not found");
628 assert_eq!(login_lookup, login);
629
630 let device = Device::generate(&mut rng);
632 let compat_session = repo
633 .compat_session()
634 .add(&mut rng, &clock, &user, device, None, false, None)
635 .await
636 .unwrap();
637
638 let res = repo
641 .compat_sso_login()
642 .exchange(&clock, login.clone(), &compat_session)
643 .await;
644 assert!(res.is_err());
645
646 let browser_session = repo
648 .browser_session()
649 .add(&mut rng, &clock, &user, None)
650 .await
651 .unwrap();
652
653 let login = repo
655 .compat_sso_login()
656 .fulfill(&clock, login, &browser_session)
657 .await
658 .unwrap();
659 assert!(login.is_fulfilled());
660
661 assert_eq!(repo.compat_sso_login().count(all).await.unwrap(), 1);
663 assert_eq!(repo.compat_sso_login().count(for_user).await.unwrap(), 1);
664 assert_eq!(repo.compat_sso_login().count(pending).await.unwrap(), 0);
665 assert_eq!(repo.compat_sso_login().count(fulfilled).await.unwrap(), 1);
666 assert_eq!(repo.compat_sso_login().count(exchanged).await.unwrap(), 0);
667
668 let res = repo
671 .compat_sso_login()
672 .fulfill(&clock, login.clone(), &browser_session)
673 .await;
674 assert!(res.is_err());
675
676 let login = repo
678 .compat_sso_login()
679 .exchange(&clock, login, &compat_session)
680 .await
681 .unwrap();
682 assert!(login.is_exchanged());
683
684 assert_eq!(repo.compat_sso_login().count(all).await.unwrap(), 1);
686 assert_eq!(repo.compat_sso_login().count(for_user).await.unwrap(), 1);
687 assert_eq!(repo.compat_sso_login().count(pending).await.unwrap(), 0);
688 assert_eq!(repo.compat_sso_login().count(fulfilled).await.unwrap(), 0);
689 assert_eq!(repo.compat_sso_login().count(exchanged).await.unwrap(), 1);
690
691 let res = repo
694 .compat_sso_login()
695 .exchange(&clock, login.clone(), &compat_session)
696 .await;
697 assert!(res.is_err());
698
699 let res = repo
702 .compat_sso_login()
703 .fulfill(&clock, login.clone(), &browser_session)
704 .await;
705 assert!(res.is_err());
706
707 let pagination = Pagination::first(10);
708
709 let logins = repo.compat_sso_login().list(all, pagination).await.unwrap();
711 assert!(!logins.has_next_page);
712 assert_eq!(logins.edges.len(), 1);
713 assert_eq!(logins.edges[0].node, login);
714
715 let logins = repo
717 .compat_sso_login()
718 .list(for_user, pagination)
719 .await
720 .unwrap();
721 assert!(!logins.has_next_page);
722 assert_eq!(logins.edges.len(), 1);
723 assert_eq!(logins.edges[0].node, login);
724
725 let logins = repo
727 .compat_sso_login()
728 .list(for_user.pending_only(), pagination)
729 .await
730 .unwrap();
731 assert!(!logins.has_next_page);
732 assert!(logins.edges.is_empty());
733
734 let logins = repo
736 .compat_sso_login()
737 .list(for_user.fulfilled_only(), pagination)
738 .await
739 .unwrap();
740 assert!(!logins.has_next_page);
741 assert!(logins.edges.is_empty());
742
743 let logins = repo
745 .compat_sso_login()
746 .list(for_user.exchanged_only(), pagination)
747 .await
748 .unwrap();
749 assert!(!logins.has_next_page);
750 assert_eq!(logins.edges.len(), 1);
751 assert_eq!(logins.edges[0].node, login);
752 }
753}