mod access_token;
mod authorization_grant;
mod client;
mod device_code_grant;
mod refresh_token;
mod session;
pub use self::{
access_token::PgOAuth2AccessTokenRepository,
authorization_grant::PgOAuth2AuthorizationGrantRepository, client::PgOAuth2ClientRepository,
device_code_grant::PgOAuth2DeviceCodeGrantRepository,
refresh_token::PgOAuth2RefreshTokenRepository, session::PgOAuth2SessionRepository,
};
#[cfg(test)]
mod tests {
    //! Integration tests for the PostgreSQL-backed OAuth 2.0 repositories.
    //!
    //! Each test runs against the pool provided by `#[sqlx::test]` (migrated
    //! with `crate::MIGRATOR`), and uses an RNG seeded with a fixed value plus a
    //! `MockClock`, so generated IDs and timestamps are reproducible.

    use chrono::Duration;
    use mas_data_model::{AuthorizationCode, UserAgent};
    use mas_storage::{
        clock::MockClock,
        oauth2::{OAuth2DeviceCodeGrantParams, OAuth2SessionFilter, OAuth2SessionRepository},
        Clock, Pagination,
    };
    use oauth2_types::{
        requests::{GrantType, ResponseMode},
        scope::{Scope, EMAIL, OPENID, PROFILE},
    };
    use rand::SeedableRng;
    use rand_chacha::ChaChaRng;
    use sqlx::PgPool;
    use ulid::Ulid;

    use crate::PgRepository;

    /// Walks one authorization-code flow through the client, authorization
    /// grant, consent, session, access token and refresh token repositories,
    /// checking lookups, state transitions and validity along the way.
    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_repositories(pool: PgPool) {
        let mut rng = ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();
        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();

        // Nothing has been inserted yet, so client lookups return `None`
        let client = repo.oauth2_client().lookup(Ulid::nil()).await.unwrap();
        assert_eq!(client, None);

        let client = repo
            .oauth2_client()
            .find_by_client_id("some-client-id")
            .await
            .unwrap();
        assert_eq!(client, None);

        // Register a client
        let client = repo
            .oauth2_client()
            .add(
                &mut rng,
                &clock,
                vec!["https://example.com/redirect".parse().unwrap()],
                None,
                None,
                vec![GrantType::AuthorizationCode],
                Some("Test client".to_owned()),
                Some("https://example.com/logo.png".parse().unwrap()),
                Some("https://example.com/".parse().unwrap()),
                Some("https://example.com/policy".parse().unwrap()),
                Some("https://example.com/tos".parse().unwrap()),
                Some("https://example.com/jwks.json".parse().unwrap()),
                None,
                None,
                None,
                None,
                None,
                Some("https://example.com/login".parse().unwrap()),
            )
            .await
            .unwrap();

        // The stored client round-trips through both lookup paths
        let client_lookup = repo
            .oauth2_client()
            .lookup(client.id)
            .await
            .unwrap()
            .expect("client not found");
        assert_eq!(client, client_lookup);

        let client_lookup = repo
            .oauth2_client()
            .find_by_client_id(&client.client_id)
            .await
            .unwrap()
            .expect("client not found");
        assert_eq!(client, client_lookup);

        // No authorization grant exists yet
        let grant = repo
            .oauth2_authorization_grant()
            .lookup(Ulid::nil())
            .await
            .unwrap();
        assert_eq!(grant, None);

        let grant = repo
            .oauth2_authorization_grant()
            .find_by_code("code")
            .await
            .unwrap();
        assert_eq!(grant, None);

        // Start an authorization grant carrying the authorization code "code"
        let grant = repo
            .oauth2_authorization_grant()
            .add(
                &mut rng,
                &clock,
                &client,
                "https://example.com/redirect".parse().unwrap(),
                Scope::from_iter([OPENID]),
                Some(AuthorizationCode {
                    code: "code".to_owned(),
                    pkce: None,
                }),
                Some("state".to_owned()),
                Some("nonce".to_owned()),
                None,
                ResponseMode::Query,
                true,
                false,
            )
            .await
            .unwrap();
        assert!(grant.is_pending());

        // The grant can be found by its ID and by its authorization code
        let grant_lookup = repo
            .oauth2_authorization_grant()
            .lookup(grant.id)
            .await
            .unwrap()
            .expect("grant not found");
        assert_eq!(grant, grant_lookup);

        let grant_lookup = repo
            .oauth2_authorization_grant()
            .find_by_code("code")
            .await
            .unwrap()
            .expect("grant not found");
        assert_eq!(grant, grant_lookup);

        // Create a user with a browser session to own the OAuth 2.0 session
        let user = repo
            .user()
            .add(&mut rng, &clock, "john".to_owned())
            .await
            .unwrap();

        let user_session = repo
            .browser_session()
            .add(&mut rng, &clock, &user, None)
            .await
            .unwrap();

        // The user has not consented to anything for this client yet
        let consent = repo
            .oauth2_client()
            .get_consent_for_user(&client, &user)
            .await
            .unwrap();
        assert!(consent.is_empty());

        // Record consent, then read it back
        let scope = Scope::from_iter([OPENID]);
        repo.oauth2_client()
            .give_consent_for_user(&mut rng, &clock, &client, &user, &scope)
            .await
            .unwrap();

        let consent = repo
            .oauth2_client()
            .get_consent_for_user(&client, &user)
            .await
            .unwrap();
        assert_eq!(scope, consent);

        // No OAuth 2.0 session exists yet
        let session = repo.oauth2_session().lookup(Ulid::nil()).await.unwrap();
        assert_eq!(session, None);

        // Create a session from the browser session with the grant's scope
        let session = repo
            .oauth2_session()
            .add_from_browser_session(
                &mut rng,
                &clock,
                &client,
                &user_session,
                grant.scope.clone(),
            )
            .await
            .unwrap();

        // Fulfilling the grant with that session moves it out of "pending"
        let grant = repo
            .oauth2_authorization_grant()
            .fulfill(&clock, &session, grant)
            .await
            .unwrap();
        assert!(grant.is_fulfilled());

        let session_lookup = repo
            .oauth2_session()
            .lookup(session.id)
            .await
            .unwrap()
            .expect("session not found");
        assert_eq!(session, session_lookup);

        // Exchange the code
        let grant = repo
            .oauth2_authorization_grant()
            .exchange(&clock, grant)
            .await
            .unwrap();
        assert!(grant.is_exchanged());

        // No access token exists yet
        let token = repo
            .oauth2_access_token()
            .lookup(Ulid::nil())
            .await
            .unwrap();
        assert_eq!(token, None);

        let token = repo
            .oauth2_access_token()
            .find_by_token("aabbcc")
            .await
            .unwrap();
        assert_eq!(token, None);

        // Issue an access token with a 5-minute lifetime
        let access_token = repo
            .oauth2_access_token()
            .add(
                &mut rng,
                &clock,
                &session,
                "aabbcc".to_owned(),
                Some(Duration::try_minutes(5).unwrap()),
            )
            .await
            .unwrap();

        let access_token_lookup = repo
            .oauth2_access_token()
            .lookup(access_token.id)
            .await
            .unwrap()
            .expect("token not found");
        assert_eq!(access_token, access_token_lookup);

        let access_token_lookup = repo
            .oauth2_access_token()
            .find_by_token("aabbcc")
            .await
            .unwrap()
            .expect("token not found");
        assert_eq!(access_token, access_token_lookup);

        // No refresh token exists yet
        let refresh_token = repo
            .oauth2_refresh_token()
            .lookup(Ulid::nil())
            .await
            .unwrap();
        assert_eq!(refresh_token, None);

        let refresh_token = repo
            .oauth2_refresh_token()
            .find_by_token("aabbcc")
            .await
            .unwrap();
        assert_eq!(refresh_token, None);

        // Issue a refresh token tied to the session and access token
        let refresh_token = repo
            .oauth2_refresh_token()
            .add(
                &mut rng,
                &clock,
                &session,
                &access_token,
                "aabbcc".to_owned(),
            )
            .await
            .unwrap();

        let refresh_token_lookup = repo
            .oauth2_refresh_token()
            .lookup(refresh_token.id)
            .await
            .unwrap()
            .expect("refresh token not found");
        assert_eq!(refresh_token, refresh_token_lookup);

        let refresh_token_lookup = repo
            .oauth2_refresh_token()
            .find_by_token("aabbcc")
            .await
            .unwrap()
            .expect("refresh token not found");
        assert_eq!(refresh_token, refresh_token_lookup);

        // The access token is valid now, and no longer valid once its
        // 5-minute lifetime has elapsed
        assert!(access_token.is_valid(clock.now()));
        clock.advance(Duration::try_minutes(6).unwrap());
        assert!(!access_token.is_valid(clock.now()));

        // Rewind the clock so the token is valid again; the revocation check
        // below must then fail because of the revocation, not the expiry
        clock.advance(Duration::try_minutes(-6).unwrap());
        assert!(access_token.is_valid(clock.now()));

        let access_token = repo
            .oauth2_access_token()
            .revoke(&clock, access_token)
            .await
            .unwrap();
        assert!(!access_token.is_valid(clock.now()));

        // Consuming the refresh token invalidates it
        assert!(refresh_token.is_valid());
        let refresh_token = repo
            .oauth2_refresh_token()
            .consume(&clock, refresh_token)
            .await
            .unwrap();
        assert!(!refresh_token.is_valid());

        // Recording a user agent updates the session and is persisted
        assert!(session.user_agent.is_none());
        let session = repo
            .oauth2_session()
            .record_user_agent(session, UserAgent::parse("Mozilla/5.0".to_owned()))
            .await
            .unwrap();
        assert_eq!(session.user_agent.as_deref(), Some("Mozilla/5.0"));

        let session = repo
            .oauth2_session()
            .lookup(session.id)
            .await
            .unwrap()
            .expect("session not found");
        assert_eq!(session.user_agent.as_deref(), Some("Mozilla/5.0"));

        // Finishing the session invalidates it
        assert!(session.is_valid());
        let session = repo.oauth2_session().finish(&clock, session).await.unwrap();
        assert!(!session.is_valid());
    }

    /// Checks `OAuth2SessionRepository::list`, `count` and `finish_bulk`
    /// against every combination of user, client, state and scope filters
    /// over four sessions (two clients x two users).
    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_list_sessions(pool: PgPool) {
        let mut rng = ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();
        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();

        // Two users, each with a browser session
        let user1 = repo
            .user()
            .add(&mut rng, &clock, "alice".to_owned())
            .await
            .unwrap();
        let user1_session = repo
            .browser_session()
            .add(&mut rng, &clock, &user1, None)
            .await
            .unwrap();

        let user2 = repo
            .user()
            .add(&mut rng, &clock, "bob".to_owned())
            .await
            .unwrap();
        let user2_session = repo
            .browser_session()
            .add(&mut rng, &clock, &user2, None)
            .await
            .unwrap();

        // Two clients
        let client1 = repo
            .oauth2_client()
            .add(
                &mut rng,
                &clock,
                vec!["https://first.example.com/redirect".parse().unwrap()],
                None,
                None,
                vec![GrantType::AuthorizationCode],
                Some("First client".to_owned()),
                Some("https://first.example.com/logo.png".parse().unwrap()),
                Some("https://first.example.com/".parse().unwrap()),
                Some("https://first.example.com/policy".parse().unwrap()),
                Some("https://first.example.com/tos".parse().unwrap()),
                Some("https://first.example.com/jwks.json".parse().unwrap()),
                None,
                None,
                None,
                None,
                None,
                Some("https://first.example.com/login".parse().unwrap()),
            )
            .await
            .unwrap();
        let client2 = repo
            .oauth2_client()
            .add(
                &mut rng,
                &clock,
                vec!["https://second.example.com/redirect".parse().unwrap()],
                None,
                None,
                vec![GrantType::AuthorizationCode],
                Some("Second client".to_owned()),
                Some("https://second.example.com/logo.png".parse().unwrap()),
                Some("https://second.example.com/".parse().unwrap()),
                Some("https://second.example.com/policy".parse().unwrap()),
                Some("https://second.example.com/tos".parse().unwrap()),
                Some("https://second.example.com/jwks.json".parse().unwrap()),
                None,
                None,
                None,
                None,
                None,
                Some("https://second.example.com/login".parse().unwrap()),
            )
            .await
            .unwrap();

        // `sessionCU` is the session of client C for user U. client1 sessions
        // get `openid email`, client2 sessions get `openid profile`. The clock
        // advances between creations; the list assertions below expect results
        // in creation order.
        let scope = Scope::from_iter([OPENID, EMAIL]);
        let scope2 = Scope::from_iter([OPENID, PROFILE]);
        let session11 = repo
            .oauth2_session()
            .add_from_browser_session(&mut rng, &clock, &client1, &user1_session, scope.clone())
            .await
            .unwrap();
        clock.advance(Duration::try_minutes(1).unwrap());

        let session12 = repo
            .oauth2_session()
            .add_from_browser_session(&mut rng, &clock, &client1, &user2_session, scope.clone())
            .await
            .unwrap();
        clock.advance(Duration::try_minutes(1).unwrap());

        let session21 = repo
            .oauth2_session()
            .add_from_browser_session(&mut rng, &clock, &client2, &user1_session, scope2.clone())
            .await
            .unwrap();
        clock.advance(Duration::try_minutes(1).unwrap());

        let session22 = repo
            .oauth2_session()
            .add_from_browser_session(&mut rng, &clock, &client2, &user2_session, scope2.clone())
            .await
            .unwrap();
        clock.advance(Duration::try_minutes(1).unwrap());

        // Finish two sessions so that each user and each client has exactly one
        // active and one finished session
        let session11 = repo
            .oauth2_session()
            .finish(&clock, session11)
            .await
            .unwrap();
        let session22 = repo
            .oauth2_session()
            .finish(&clock, session22)
            .await
            .unwrap();

        // A page size of 10 holds every session, so no page has a successor
        let pagination = Pagination::first(10);

        // No filter: all four sessions
        let filter = OAuth2SessionFilter::new();
        let list = repo
            .oauth2_session()
            .list(filter, pagination)
            .await
            .unwrap();
        assert!(!list.has_next_page);
        assert_eq!(list.edges.len(), 4);
        assert_eq!(list.edges[0], session11);
        assert_eq!(list.edges[1], session12);
        assert_eq!(list.edges[2], session21);
        assert_eq!(list.edges[3], session22);

        assert_eq!(repo.oauth2_session().count(filter).await.unwrap(), 4);

        // By user
        let filter = OAuth2SessionFilter::new().for_user(&user1);
        let list = repo
            .oauth2_session()
            .list(filter, pagination)
            .await
            .unwrap();
        assert!(!list.has_next_page);
        assert_eq!(list.edges.len(), 2);
        assert_eq!(list.edges[0], session11);
        assert_eq!(list.edges[1], session21);

        assert_eq!(repo.oauth2_session().count(filter).await.unwrap(), 2);

        // By client
        let filter = OAuth2SessionFilter::new().for_client(&client1);
        let list = repo
            .oauth2_session()
            .list(filter, pagination)
            .await
            .unwrap();
        assert!(!list.has_next_page);
        assert_eq!(list.edges.len(), 2);
        assert_eq!(list.edges[0], session11);
        assert_eq!(list.edges[1], session12);

        assert_eq!(repo.oauth2_session().count(filter).await.unwrap(), 2);

        // By user and client
        let filter = OAuth2SessionFilter::new()
            .for_user(&user2)
            .for_client(&client2);
        let list = repo
            .oauth2_session()
            .list(filter, pagination)
            .await
            .unwrap();
        assert!(!list.has_next_page);
        assert_eq!(list.edges.len(), 1);
        assert_eq!(list.edges[0], session22);

        assert_eq!(repo.oauth2_session().count(filter).await.unwrap(), 1);

        // Active sessions only
        let filter = OAuth2SessionFilter::new().active_only();
        let list = repo
            .oauth2_session()
            .list(filter, pagination)
            .await
            .unwrap();
        assert!(!list.has_next_page);
        assert_eq!(list.edges.len(), 2);
        assert_eq!(list.edges[0], session12);
        assert_eq!(list.edges[1], session21);

        assert_eq!(repo.oauth2_session().count(filter).await.unwrap(), 2);

        // Finished sessions only
        let filter = OAuth2SessionFilter::new().finished_only();
        let list = repo
            .oauth2_session()
            .list(filter, pagination)
            .await
            .unwrap();
        assert!(!list.has_next_page);
        assert_eq!(list.edges.len(), 2);
        assert_eq!(list.edges[0], session11);
        assert_eq!(list.edges[1], session22);

        assert_eq!(repo.oauth2_session().count(filter).await.unwrap(), 2);

        // Finished sessions of a given user
        let filter = OAuth2SessionFilter::new().finished_only().for_user(&user2);
        let list = repo
            .oauth2_session()
            .list(filter, pagination)
            .await
            .unwrap();
        assert!(!list.has_next_page);
        assert_eq!(list.edges.len(), 1);
        assert_eq!(list.edges[0], session22);

        assert_eq!(repo.oauth2_session().count(filter).await.unwrap(), 1);

        // Finished sessions of a given client
        let filter = OAuth2SessionFilter::new()
            .finished_only()
            .for_client(&client2);
        let list = repo
            .oauth2_session()
            .list(filter, pagination)
            .await
            .unwrap();
        assert!(!list.has_next_page);
        assert_eq!(list.edges.len(), 1);
        assert_eq!(list.edges[0], session22);

        assert_eq!(repo.oauth2_session().count(filter).await.unwrap(), 1);

        // Active sessions of a given user
        let filter = OAuth2SessionFilter::new().active_only().for_user(&user2);
        let list = repo
            .oauth2_session()
            .list(filter, pagination)
            .await
            .unwrap();
        assert!(!list.has_next_page);
        assert_eq!(list.edges.len(), 1);
        assert_eq!(list.edges[0], session12);

        assert_eq!(repo.oauth2_session().count(filter).await.unwrap(), 1);

        // Active sessions of a given client
        let filter = OAuth2SessionFilter::new()
            .active_only()
            .for_client(&client2);
        let list = repo
            .oauth2_session()
            .list(filter, pagination)
            .await
            .unwrap();
        assert!(!list.has_next_page);
        assert_eq!(list.edges.len(), 1);
        assert_eq!(list.edges[0], session21);

        assert_eq!(repo.oauth2_session().count(filter).await.unwrap(), 1);

        // `openid` is granted to every session
        let scope = Scope::from_iter([OPENID]);
        let filter = OAuth2SessionFilter::new().with_scope(&scope);
        let list = repo
            .oauth2_session()
            .list(filter, pagination)
            .await
            .unwrap();
        assert!(!list.has_next_page);
        assert_eq!(list.edges.len(), 4);
        assert_eq!(list.edges[0], session11);
        assert_eq!(list.edges[1], session12);
        assert_eq!(list.edges[2], session21);
        assert_eq!(list.edges[3], session22);
        assert_eq!(repo.oauth2_session().count(filter).await.unwrap(), 4);

        // `openid email` is only granted to the client1 sessions
        let scope = Scope::from_iter([OPENID, EMAIL]);
        let filter = OAuth2SessionFilter::new().with_scope(&scope);
        let list = repo
            .oauth2_session()
            .list(filter, pagination)
            .await
            .unwrap();
        assert!(!list.has_next_page);
        assert_eq!(list.edges.len(), 2);
        assert_eq!(list.edges[0], session11);
        assert_eq!(list.edges[1], session12);
        assert_eq!(repo.oauth2_session().count(filter).await.unwrap(), 2);

        // Scope combined with a user filter
        let filter = OAuth2SessionFilter::new()
            .with_scope(&scope)
            .for_user(&user1);
        let list = repo
            .oauth2_session()
            .list(filter, pagination)
            .await
            .unwrap();
        assert!(!list.has_next_page);
        assert_eq!(list.edges.len(), 1);
        assert_eq!(list.edges[0], session11);
        assert_eq!(repo.oauth2_session().count(filter).await.unwrap(), 1);

        // Bulk-finish client1's active sessions: only session12 is affected,
        // leaving three finished sessions and one active (session21)
        let affected = repo
            .oauth2_session()
            .finish_bulk(
                &clock,
                OAuth2SessionFilter::new()
                    .for_client(&client1)
                    .active_only(),
            )
            .await
            .unwrap();
        assert_eq!(affected, 1);
        assert_eq!(
            repo.oauth2_session()
                .count(OAuth2SessionFilter::new().finished_only())
                .await
                .unwrap(),
            3
        );
        assert_eq!(
            repo.oauth2_session()
                .count(OAuth2SessionFilter::new().active_only())
                .await
                .unwrap(),
            1
        );
    }

    /// Exercises the device code grant repository: lookups by ID, device code
    /// and user code, plus the allowed and rejected state transitions between
    /// pending, fulfilled, rejected and exchanged.
    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_device_code_grant_repository(pool: PgPool) {
        let mut rng = ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();
        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();

        // Provision a client and a user with a browser session
        let client = repo
            .oauth2_client()
            .add(
                &mut rng,
                &clock,
                vec!["https://example.com/redirect".parse().unwrap()],
                None,
                None,
                vec![GrantType::AuthorizationCode],
                Some("Example".to_owned()),
                Some("https://example.com/logo.png".parse().unwrap()),
                Some("https://example.com/".parse().unwrap()),
                Some("https://example.com/policy".parse().unwrap()),
                Some("https://example.com/tos".parse().unwrap()),
                Some("https://example.com/jwks.json".parse().unwrap()),
                None,
                None,
                None,
                None,
                None,
                Some("https://example.com/login".parse().unwrap()),
            )
            .await
            .unwrap();

        let user = repo
            .user()
            .add(&mut rng, &clock, "john".to_owned())
            .await
            .unwrap();

        let browser_session = repo
            .browser_session()
            .add(&mut rng, &clock, &user, None)
            .await
            .unwrap();

        // Create a first grant; it starts out pending
        let user_code = "usercode";
        let device_code = "devicecode";
        let scope = Scope::from_iter([OPENID, EMAIL]);

        let grant = repo
            .oauth2_device_code_grant()
            .add(
                &mut rng,
                &clock,
                OAuth2DeviceCodeGrantParams {
                    client: &client,
                    scope: scope.clone(),
                    device_code: device_code.to_owned(),
                    user_code: user_code.to_owned(),
                    expires_in: Duration::try_minutes(5).unwrap(),
                    ip_address: None,
                    user_agent: None,
                },
            )
            .await
            .unwrap();

        assert!(grant.is_pending());

        // The grant is reachable by ID, by device code and by user code
        let id = grant.id;
        let lookup = repo.oauth2_device_code_grant().lookup(id).await.unwrap();
        assert_eq!(lookup.as_ref(), Some(&grant));

        let lookup = repo
            .oauth2_device_code_grant()
            .find_by_device_code(device_code)
            .await
            .unwrap();
        assert_eq!(lookup.as_ref(), Some(&grant));

        let lookup = repo
            .oauth2_device_code_grant()
            .find_by_user_code(user_code)
            .await
            .unwrap();
        assert_eq!(lookup.as_ref(), Some(&grant));

        // Pending -> fulfilled
        let grant = repo
            .oauth2_device_code_grant()
            .fulfill(&clock, grant, &browser_session)
            .await
            .unwrap();

        assert!(!grant.is_pending());
        assert!(grant.is_fulfilled());

        // A fulfilled grant can be neither rejected nor fulfilled again.
        // These calls take the grant by value, so it is re-fetched after each
        // failed attempt.
        let res = repo
            .oauth2_device_code_grant()
            .reject(&clock, grant, &browser_session)
            .await;
        assert!(res.is_err());

        let grant = repo
            .oauth2_device_code_grant()
            .lookup(id)
            .await
            .unwrap()
            .unwrap();

        let res = repo
            .oauth2_device_code_grant()
            .fulfill(&clock, grant, &browser_session)
            .await;
        assert!(res.is_err());

        let grant = repo
            .oauth2_device_code_grant()
            .lookup(id)
            .await
            .unwrap()
            .unwrap();

        // Fulfilled -> exchanged, using a fresh OAuth 2.0 session
        let session = repo
            .oauth2_session()
            .add_from_browser_session(&mut rng, &clock, &client, &browser_session, scope.clone())
            .await
            .unwrap();

        let grant = repo
            .oauth2_device_code_grant()
            .exchange(&clock, grant, &session)
            .await
            .unwrap();

        assert!(!grant.is_pending());
        assert!(!grant.is_fulfilled());
        assert!(grant.is_exchanged());

        // An exchanged grant cannot be exchanged a second time
        let res = repo
            .oauth2_device_code_grant()
            .exchange(&clock, grant, &session)
            .await;
        assert!(res.is_err());

        // A second, independent grant to exercise the rejection path
        let grant = repo
            .oauth2_device_code_grant()
            .add(
                &mut rng,
                &clock,
                OAuth2DeviceCodeGrantParams {
                    client: &client,
                    scope: scope.clone(),
                    device_code: "second_devicecode".to_owned(),
                    user_code: "second_usercode".to_owned(),
                    expires_in: Duration::try_minutes(5).unwrap(),
                    ip_address: None,
                    user_agent: None,
                },
            )
            .await
            .unwrap();

        // Like the first grant, it must start out pending so the rejection
        // below is a real state transition
        assert!(grant.is_pending());

        let id = grant.id;

        // Pending -> rejected
        let grant = repo
            .oauth2_device_code_grant()
            .reject(&clock, grant, &browser_session)
            .await
            .unwrap();

        assert!(!grant.is_pending());
        assert!(grant.is_rejected());

        // A rejected grant can be neither rejected again, fulfilled nor
        // exchanged
        let res = repo
            .oauth2_device_code_grant()
            .reject(&clock, grant, &browser_session)
            .await;
        assert!(res.is_err());

        let grant = repo
            .oauth2_device_code_grant()
            .lookup(id)
            .await
            .unwrap()
            .unwrap();

        let res = repo
            .oauth2_device_code_grant()
            .fulfill(&clock, grant, &browser_session)
            .await;
        assert!(res.is_err());

        let grant = repo
            .oauth2_device_code_grant()
            .lookup(id)
            .await
            .unwrap()
            .unwrap();

        let res = repo
            .oauth2_device_code_grant()
            .exchange(&clock, grant, &session)
            .await;
        assert!(res.is_err());
    }
}