mod access_token;
mod refresh_token;
mod session;
mod sso_login;
pub use self::{
access_token::PgCompatAccessTokenRepository, refresh_token::PgCompatRefreshTokenRepository,
session::PgCompatSessionRepository, sso_login::PgCompatSsoLoginRepository,
};
#[cfg(test)]
mod tests {
use chrono::Duration;
use mas_data_model::{Device, UserAgent};
use mas_storage::{
clock::MockClock,
compat::{
CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionFilter,
CompatSessionRepository, CompatSsoLoginFilter,
},
user::UserRepository,
Clock, Pagination, RepositoryAccess,
};
use rand::SeedableRng;
use rand_chacha::ChaChaRng;
use sqlx::PgPool;
use ulid::Ulid;
use crate::PgRepository;
/// Walks a compat session through add, lookup, user-agent recording and finish, checking the
/// count/list results of each filter along the way, then covers the SSO-login/unknown filters
/// and bulk finishing.
#[sqlx::test(migrator = "crate::MIGRATOR")]
async fn test_session_repository(pool: PgPool) {
    let clock = MockClock::default();
    let mut rng = ChaChaRng::seed_from_u64(42);
    let mut repo = PgRepository::from_pool(&pool).await.unwrap();

    // Every session created below belongs to this user.
    let user = repo
        .user()
        .add(&mut rng, &clock, String::from("john"))
        .await
        .unwrap();

    let first_page = Pagination::first(10);
    let by_user = CompatSessionFilter::new().for_user(&user);
    let only_active = by_user.active_only();
    let only_finished = by_user.finished_only();

    // No session has been added yet: every filter counts zero...
    let count = repo.compat_session().count(by_user).await.unwrap();
    assert_eq!(count, 0);
    let count = repo.compat_session().count(only_active).await.unwrap();
    assert_eq!(count, 0);
    let count = repo.compat_session().count(only_finished).await.unwrap();
    assert_eq!(count, 0);

    // ...and every listing is empty.
    let page = repo
        .compat_session()
        .list(by_user, first_page)
        .await
        .unwrap();
    assert!(page.edges.is_empty());
    let page = repo
        .compat_session()
        .list(only_active, first_page)
        .await
        .unwrap();
    assert!(page.edges.is_empty());
    let page = repo
        .compat_session()
        .list(only_finished, first_page)
        .await
        .unwrap();
    assert!(page.edges.is_empty());

    // Add a session on a freshly generated device.
    let device = Device::generate(&mut rng);
    let expected_device = device.as_str().to_owned();
    let session = repo
        .compat_session()
        .add(&mut rng, &clock, &user, device.clone(), None, false)
        .await
        .unwrap();
    assert_eq!(session.user_id, user.id);
    assert_eq!(session.device.as_str(), expected_device);
    assert!(session.is_valid());
    assert!(!session.is_finished());

    // The new session is counted as active, not as finished.
    let count = repo.compat_session().count(by_user).await.unwrap();
    assert_eq!(count, 1);
    let count = repo.compat_session().count(only_active).await.unwrap();
    assert_eq!(count, 1);
    let count = repo.compat_session().count(only_finished).await.unwrap();
    assert_eq!(count, 0);

    let page = repo
        .compat_session()
        .list(by_user, first_page)
        .await
        .unwrap();
    assert_eq!(page.edges.len(), 1);
    assert_eq!(page.edges[0].0.id, session.id);
    let page = repo
        .compat_session()
        .list(only_active, first_page)
        .await
        .unwrap();
    assert_eq!(page.edges.len(), 1);
    assert_eq!(page.edges[0].0.id, session.id);
    let page = repo
        .compat_session()
        .list(only_finished, first_page)
        .await
        .unwrap();
    assert!(page.edges.is_empty());

    // Looking the session up by ID returns the same data, with no user agent recorded yet.
    let found = repo
        .compat_session()
        .lookup(session.id)
        .await
        .unwrap()
        .expect("compat session not found");
    assert_eq!(found.id, session.id);
    assert_eq!(found.user_id, user.id);
    assert_eq!(found.device.as_str(), expected_device);
    assert!(found.is_valid());
    assert!(!found.is_finished());
    assert!(found.user_agent.is_none());

    // Record a user agent; it must be visible on the returned value and on a fresh lookup.
    let user_agent = UserAgent::parse("Mozilla/5.0".to_owned());
    let session = repo
        .compat_session()
        .record_user_agent(found, user_agent)
        .await
        .unwrap();
    assert_eq!(session.user_agent.as_deref(), Some("Mozilla/5.0"));

    let found = repo
        .compat_session()
        .lookup(session.id)
        .await
        .unwrap()
        .expect("compat session not found");
    assert_eq!(found.user_agent.as_deref(), Some("Mozilla/5.0"));

    // Filtering on user and device finds exactly that session.
    let by_user_and_device = CompatSessionFilter::new()
        .for_user(&user)
        .for_device(&device);
    let page = repo
        .compat_session()
        .list(by_user_and_device, first_page)
        .await
        .unwrap();
    assert_eq!(page.edges.len(), 1);
    let found = &page.edges[0].0;
    assert_eq!(found.id, session.id);
    assert_eq!(found.user_id, user.id);
    assert_eq!(found.device.as_str(), expected_device);
    assert!(found.is_valid());
    assert!(!found.is_finished());

    // Finish the session: it moves from the active set to the finished set.
    let session = repo.compat_session().finish(&clock, session).await.unwrap();
    assert!(!session.is_valid());
    assert!(session.is_finished());

    let count = repo.compat_session().count(by_user).await.unwrap();
    assert_eq!(count, 1);
    let count = repo.compat_session().count(only_active).await.unwrap();
    assert_eq!(count, 0);
    let count = repo.compat_session().count(only_finished).await.unwrap();
    assert_eq!(count, 1);

    let page = repo
        .compat_session()
        .list(by_user, first_page)
        .await
        .unwrap();
    assert_eq!(page.edges.len(), 1);
    assert_eq!(page.edges[0].0.id, session.id);
    let page = repo
        .compat_session()
        .list(only_active, first_page)
        .await
        .unwrap();
    assert!(page.edges.is_empty());
    let page = repo
        .compat_session()
        .list(only_finished, first_page)
        .await
        .unwrap();
    assert_eq!(page.edges.len(), 1);
    assert_eq!(page.edges[0].0.id, session.id);

    // A fresh lookup also reports the session as finished.
    let found = repo
        .compat_session()
        .lookup(session.id)
        .await
        .unwrap()
        .expect("compat session not found");
    assert!(!found.is_valid());
    assert!(found.is_finished());

    // This first session was never attached to an SSO login; it is the one expected under
    // the `unknown_only` filter below.
    let plain_session = session;

    // Create a second session and fulfill an SSO login with it.
    let login = repo
        .compat_sso_login()
        .add(
            &mut rng,
            &clock,
            "login-token".to_owned(),
            "https://example.com/callback".parse().unwrap(),
        )
        .await
        .unwrap();
    assert!(login.is_pending());

    let device = Device::generate(&mut rng);
    let sso_session = repo
        .compat_session()
        .add(&mut rng, &clock, &user, device, None, false)
        .await
        .unwrap();

    let login = repo
        .compat_sso_login()
        .fulfill(&clock, login, &sso_session)
        .await
        .unwrap();
    assert!(login.is_fulfilled());

    // Split the user's two sessions by how they were created.
    let by_user = CompatSessionFilter::new().for_user(&user);
    let via_sso = by_user.sso_login_only();
    let via_unknown = by_user.unknown_only();

    let count = repo.compat_session().count(by_user).await.unwrap();
    assert_eq!(count, 2);
    let count = repo.compat_session().count(via_sso).await.unwrap();
    assert_eq!(count, 1);
    let count = repo.compat_session().count(via_unknown).await.unwrap();
    assert_eq!(count, 1);

    let page = repo
        .compat_session()
        .list(via_sso, first_page)
        .await
        .unwrap();
    assert_eq!(page.edges.len(), 1);
    assert_eq!(page.edges[0].0.id, sso_session.id);
    let page = repo
        .compat_session()
        .list(via_unknown, first_page)
        .await
        .unwrap();
    assert_eq!(page.edges.len(), 1);
    assert_eq!(page.edges[0].0.id, plain_session.id);

    // Combine the type filter with the state filter: the SSO session is still active, the
    // plain one was finished above.
    let count = repo
        .compat_session()
        .count(via_sso.active_only())
        .await
        .unwrap();
    assert_eq!(count, 1);
    let count = repo
        .compat_session()
        .count(via_sso.finished_only())
        .await
        .unwrap();
    assert_eq!(count, 0);
    let count = repo
        .compat_session()
        .count(via_unknown.active_only())
        .await
        .unwrap();
    assert_eq!(count, 0);
    let count = repo
        .compat_session()
        .count(via_unknown.finished_only())
        .await
        .unwrap();
    assert_eq!(count, 1);

    // Bulk-finish the active SSO sessions; afterwards both sessions are finished.
    let affected = repo
        .compat_session()
        .finish_bulk(&clock, via_sso.active_only())
        .await
        .unwrap();
    assert_eq!(affected, 1);

    let count = repo.compat_session().count(only_finished).await.unwrap();
    assert_eq!(count, 2);
    let count = repo.compat_session().count(only_active).await.unwrap();
    assert_eq!(count, 0);
}
/// Covers adding compat access tokens (with and without an expiry), rejecting a duplicate
/// token value, looking tokens up by ID and by value, and expiring them.
#[sqlx::test(migrator = "crate::MIGRATOR")]
async fn test_access_token_repository(pool: PgPool) {
    const FIRST_TOKEN: &str = "first_access_token";
    const SECOND_TOKEN: &str = "second_access_token";

    let clock = MockClock::default();
    let mut rng = ChaChaRng::seed_from_u64(42);
    // Lifetime of the first token, and the amount the mock clock is advanced later on.
    let one_minute = Duration::try_minutes(1).unwrap();
    let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();

    // A token needs a session, which in turn needs a user.
    let user = repo
        .user()
        .add(&mut rng, &clock, String::from("john"))
        .await
        .unwrap();

    let device = Device::generate(&mut rng);
    let session = repo
        .compat_session()
        .add(&mut rng, &clock, &user, device, None, false)
        .await
        .unwrap();

    // First token: expires after one minute.
    let first = repo
        .compat_access_token()
        .add(
            &mut rng,
            &clock,
            &session,
            String::from(FIRST_TOKEN),
            Some(one_minute),
        )
        .await
        .unwrap();
    assert_eq!(first.session_id, session.id);
    assert_eq!(first.token, FIRST_TOKEN);

    // Persist what we have so far before provoking an error.
    repo.save().await.unwrap();

    {
        // Adding the same token value again must fail. This is done through a separate
        // repository that is cancelled afterwards.
        // NOTE(review): presumably isolated so the failed statement cannot affect the main
        // repository's transaction — confirm.
        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();
        let duplicate = repo
            .compat_access_token()
            .add(
                &mut rng,
                &clock,
                &session,
                String::from(FIRST_TOKEN),
                Some(one_minute),
            )
            .await;
        assert!(duplicate.is_err());
        repo.cancel().await.unwrap();
    }

    // Continue with a fresh repository.
    let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();

    // The saved token can be found by its ID...
    let found = repo
        .compat_access_token()
        .lookup(first.id)
        .await
        .unwrap()
        .expect("compat access token not found");
    assert_eq!(first.id, found.id);
    assert_eq!(found.session_id, session.id);

    // ...and by its token value.
    let found = repo
        .compat_access_token()
        .find_by_token(FIRST_TOKEN)
        .await
        .unwrap()
        .expect("compat access token not found");
    assert_eq!(first.id, found.id);
    assert_eq!(found.session_id, session.id);

    // The token is valid now, and no longer valid once a minute has passed on the mock clock.
    assert!(first.is_valid(clock.now()));
    clock.advance(one_minute);
    assert!(!first.is_valid(clock.now()));

    // Second token: no expiry given.
    let second = repo
        .compat_access_token()
        .add(
            &mut rng,
            &clock,
            &session,
            String::from(SECOND_TOKEN),
            None,
        )
        .await
        .unwrap();
    assert_eq!(second.session_id, session.id);
    assert_eq!(second.token, SECOND_TOKEN);
    assert!(second.is_valid(clock.now()));

    // Expire it explicitly, then reload it to check that the change was stored.
    repo.compat_access_token()
        .expire(&clock, second)
        .await
        .unwrap();

    let expired = repo
        .compat_access_token()
        .find_by_token(SECOND_TOKEN)
        .await
        .unwrap()
        .expect("compat access token not found");
    assert!(!expired.is_valid(clock.now()));

    repo.save().await.unwrap();
}
/// Covers adding a compat refresh token, finding it by ID and by value, consuming it, and
/// rejecting a second consumption.
#[sqlx::test(migrator = "crate::MIGRATOR")]
async fn test_refresh_token_repository(pool: PgPool) {
    const ACCESS_TOKEN: &str = "access_token";
    const REFRESH_TOKEN: &str = "refresh_token";

    let clock = MockClock::default();
    let mut rng = ChaChaRng::seed_from_u64(42);
    let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();

    // Build the chain a refresh token hangs off: user -> session -> access token.
    let user = repo
        .user()
        .add(&mut rng, &clock, String::from("john"))
        .await
        .unwrap();

    let device = Device::generate(&mut rng);
    let session = repo
        .compat_session()
        .add(&mut rng, &clock, &user, device, None, false)
        .await
        .unwrap();

    let access_token = repo
        .compat_access_token()
        .add(
            &mut rng,
            &clock,
            &session,
            String::from(ACCESS_TOKEN),
            None,
        )
        .await
        .unwrap();

    // Add the refresh token and check the value returned by the repository.
    let refresh_token = repo
        .compat_refresh_token()
        .add(
            &mut rng,
            &clock,
            &session,
            &access_token,
            String::from(REFRESH_TOKEN),
        )
        .await
        .unwrap();
    assert_eq!(refresh_token.session_id, session.id);
    assert_eq!(refresh_token.access_token_id, access_token.id);
    assert_eq!(refresh_token.token, REFRESH_TOKEN);
    assert!(refresh_token.is_valid());
    assert!(!refresh_token.is_consumed());

    // Lookup by ID returns the same token, still unconsumed.
    let found = repo
        .compat_refresh_token()
        .lookup(refresh_token.id)
        .await
        .unwrap()
        .expect("refresh token not found");
    assert_eq!(found.id, refresh_token.id);
    assert_eq!(found.session_id, session.id);
    assert_eq!(found.access_token_id, access_token.id);
    assert_eq!(found.token, REFRESH_TOKEN);
    assert!(found.is_valid());
    assert!(!found.is_consumed());

    // Lookup by token value returns the same token as well.
    let found = repo
        .compat_refresh_token()
        .find_by_token(REFRESH_TOKEN)
        .await
        .unwrap()
        .expect("refresh token not found");
    assert_eq!(found.id, refresh_token.id);
    assert_eq!(found.session_id, session.id);
    assert_eq!(found.access_token_id, access_token.id);
    assert_eq!(found.token, REFRESH_TOKEN);
    assert!(found.is_valid());
    assert!(!found.is_consumed());

    // Consume the token: the returned value is marked consumed and no longer valid.
    let refresh_token = repo
        .compat_refresh_token()
        .consume(&clock, refresh_token)
        .await
        .unwrap();
    assert!(!refresh_token.is_valid());
    assert!(refresh_token.is_consumed());

    // A fresh lookup reports the consumed state too.
    let found = repo
        .compat_refresh_token()
        .find_by_token(REFRESH_TOKEN)
        .await
        .unwrap()
        .expect("refresh token not found");
    assert!(!found.is_valid());
    assert!(found.is_consumed());

    // Consuming an already consumed token is an error.
    let second_attempt = repo
        .compat_refresh_token()
        .consume(&clock, refresh_token)
        .await;
    assert!(second_attempt.is_err());

    repo.save().await.unwrap();
}
/// Takes a compat SSO login from pending to fulfilled to exchanged, checking the per-state
/// counts after every transition, that out-of-order transitions are rejected, and that
/// listings agree with the final state.
#[sqlx::test(migrator = "crate::MIGRATOR")]
async fn test_compat_sso_login_repository(pool: PgPool) {
    let clock = MockClock::default();
    let mut rng = ChaChaRng::seed_from_u64(42);
    let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();

    let user = repo
        .user()
        .add(&mut rng, &clock, String::from("john"))
        .await
        .unwrap();

    // Looking up the nil ULID finds nothing.
    let missing = repo.compat_sso_login().lookup(Ulid::nil()).await.unwrap();
    assert!(missing.is_none());

    let everything = CompatSsoLoginFilter::new();
    let by_user = everything.for_user(&user);
    let only_pending = everything.pending_only();
    let only_fulfilled = everything.fulfilled_only();
    let only_exchanged = everything.exchanged_only();

    // No login has been added yet, so every filter counts zero.
    let count = repo.compat_sso_login().count(everything).await.unwrap();
    assert_eq!(count, 0);
    let count = repo.compat_sso_login().count(by_user).await.unwrap();
    assert_eq!(count, 0);
    let count = repo.compat_sso_login().count(only_pending).await.unwrap();
    assert_eq!(count, 0);
    let count = repo.compat_sso_login().count(only_fulfilled).await.unwrap();
    assert_eq!(count, 0);
    let count = repo.compat_sso_login().count(only_exchanged).await.unwrap();
    assert_eq!(count, 0);

    // The token is not known before the login is added.
    let missing = repo
        .compat_sso_login()
        .find_by_token("login-token")
        .await
        .unwrap();
    assert!(missing.is_none());

    // Add the login; it starts out pending.
    let login = repo
        .compat_sso_login()
        .add(
            &mut rng,
            &clock,
            "login-token".to_owned(),
            "https://example.com/callback".parse().unwrap(),
        )
        .await
        .unwrap();
    assert!(login.is_pending());

    // One login exists, pending, and not yet counted for the user.
    let count = repo.compat_sso_login().count(everything).await.unwrap();
    assert_eq!(count, 1);
    let count = repo.compat_sso_login().count(by_user).await.unwrap();
    assert_eq!(count, 0);
    let count = repo.compat_sso_login().count(only_pending).await.unwrap();
    assert_eq!(count, 1);
    let count = repo.compat_sso_login().count(only_fulfilled).await.unwrap();
    assert_eq!(count, 0);
    let count = repo.compat_sso_login().count(only_exchanged).await.unwrap();
    assert_eq!(count, 0);

    // The login can be fetched by ID and by token, and equals what `add` returned.
    let found = repo
        .compat_sso_login()
        .lookup(login.id)
        .await
        .unwrap()
        .expect("login not found");
    assert_eq!(found, login);

    let found = repo
        .compat_sso_login()
        .find_by_token("login-token")
        .await
        .unwrap()
        .expect("login not found");
    assert_eq!(found, login);

    // Exchanging a login that is still pending is an error.
    let premature_exchange = repo
        .compat_sso_login()
        .exchange(&clock, login.clone())
        .await;
    assert!(premature_exchange.is_err());

    // Fulfill the login with a new session for the user.
    let device = Device::generate(&mut rng);
    let session = repo
        .compat_session()
        .add(&mut rng, &clock, &user, device, None, false)
        .await
        .unwrap();

    let login = repo
        .compat_sso_login()
        .fulfill(&clock, login, &session)
        .await
        .unwrap();
    assert!(login.is_fulfilled());

    // The login is now fulfilled and counted for the user.
    let count = repo.compat_sso_login().count(everything).await.unwrap();
    assert_eq!(count, 1);
    let count = repo.compat_sso_login().count(by_user).await.unwrap();
    assert_eq!(count, 1);
    let count = repo.compat_sso_login().count(only_pending).await.unwrap();
    assert_eq!(count, 0);
    let count = repo.compat_sso_login().count(only_fulfilled).await.unwrap();
    assert_eq!(count, 1);
    let count = repo.compat_sso_login().count(only_exchanged).await.unwrap();
    assert_eq!(count, 0);

    // Fulfilling a second time is an error.
    let repeated_fulfill = repo
        .compat_sso_login()
        .fulfill(&clock, login.clone(), &session)
        .await;
    assert!(repeated_fulfill.is_err());

    // Exchanging a fulfilled login succeeds.
    let login = repo
        .compat_sso_login()
        .exchange(&clock, login)
        .await
        .unwrap();
    assert!(login.is_exchanged());

    // The login is now exchanged and no longer counted as fulfilled.
    let count = repo.compat_sso_login().count(everything).await.unwrap();
    assert_eq!(count, 1);
    let count = repo.compat_sso_login().count(by_user).await.unwrap();
    assert_eq!(count, 1);
    let count = repo.compat_sso_login().count(only_pending).await.unwrap();
    assert_eq!(count, 0);
    let count = repo.compat_sso_login().count(only_fulfilled).await.unwrap();
    assert_eq!(count, 0);
    let count = repo.compat_sso_login().count(only_exchanged).await.unwrap();
    assert_eq!(count, 1);

    // Once exchanged, neither exchanging nor fulfilling again is accepted.
    let repeated_exchange = repo
        .compat_sso_login()
        .exchange(&clock, login.clone())
        .await;
    assert!(repeated_exchange.is_err());

    let late_fulfill = repo
        .compat_sso_login()
        .fulfill(&clock, login.clone(), &session)
        .await;
    assert!(late_fulfill.is_err());

    // Listings: the exchanged login shows up unfiltered, per user and under `exchanged_only`,
    // and is absent from the pending and fulfilled listings.
    let first_page = Pagination::first(10);

    let page = repo
        .compat_sso_login()
        .list(everything, first_page)
        .await
        .unwrap();
    assert!(!page.has_next_page);
    assert_eq!(page.edges, &[login.clone()]);

    let page = repo
        .compat_sso_login()
        .list(by_user, first_page)
        .await
        .unwrap();
    assert!(!page.has_next_page);
    assert_eq!(page.edges, &[login.clone()]);

    let page = repo
        .compat_sso_login()
        .list(by_user.pending_only(), first_page)
        .await
        .unwrap();
    assert!(!page.has_next_page);
    assert!(page.edges.is_empty());

    let page = repo
        .compat_sso_login()
        .list(by_user.fulfilled_only(), first_page)
        .await
        .unwrap();
    assert!(!page.has_next_page);
    assert!(page.edges.is_empty());

    let page = repo
        .compat_sso_login()
        .list(by_user.exchanged_only(), first_page)
        .await
        .unwrap();
    assert!(!page.has_next_page);
    assert_eq!(page.edges, &[login]);
}
}