use async_trait::async_trait;
use mas_data_model::{CompatSession, CompatSessionState, Device, Session, SessionState, UserAgent};
use mas_storage::{
app_session::{AppSession, AppSessionFilter, AppSessionRepository, AppSessionState},
compat::CompatSessionFilter,
oauth2::OAuth2SessionFilter,
Page, Pagination,
};
use oauth2_types::scope::{Scope, ScopeToken};
use sea_query::{
Alias, ColumnRef, CommonTableExpression, Expr, PostgresQueryBuilder, Query, UnionType,
};
use sea_query_binder::SqlxBinder;
use sqlx::PgConnection;
use ulid::Ulid;
use crate::{
errors::DatabaseInconsistencyError,
filter::StatementExt,
iden::{CompatSessions, OAuth2Sessions},
pagination::QueryBuilderExt,
DatabaseError, ExecuteExt,
};
/// An implementation of [`AppSessionRepository`] backed by a PostgreSQL
/// connection.
///
/// It lists and counts compat sessions and OAuth 2.0 sessions together as
/// [`AppSession`]s.
pub struct PgAppSessionRepository<'c> {
    /// Exclusive borrow of the connection every query is executed on.
    conn: &'c mut PgConnection,
}
impl<'c> PgAppSessionRepository<'c> {
pub fn new(conn: &'c mut PgConnection) -> Self {
Self { conn }
}
}
// Holds the row type for the `sessions` union query. The struct and the
// `AppSessionLookupIden` enum generated from it by `#[enum_def]` are
// `pub(super)`, so only the parent module can see them.
mod priv_ {
    use std::net::IpAddr;

    use chrono::{DateTime, Utc};
    use sea_query::enum_def;
    use uuid::Uuid;

    /// One row of the `sessions` union: either a compat session or an OAuth 2.0
    /// session, flattened into a single column set.
    ///
    /// `#[enum_def]` generates `AppSessionLookupIden`, whose variants are used
    /// as the column aliases of the union. Field names must therefore stay in
    /// sync with those aliases, because `FromRow` maps columns by name.
    #[derive(sqlx::FromRow)]
    #[enum_def]
    pub(super) struct AppSessionLookup {
        /// The session's own ID, taken from whichever table the row came
        /// from. Pagination orders on this column.
        pub(super) cursor: Uuid,
        /// Set for compat sessions, NULL for OAuth 2.0 sessions.
        pub(super) compat_session_id: Option<Uuid>,
        /// Set for OAuth 2.0 sessions, NULL for compat sessions.
        pub(super) oauth2_session_id: Option<Uuid>,
        /// Set for OAuth 2.0 sessions, NULL for compat sessions.
        pub(super) oauth2_client_id: Option<Uuid>,
        pub(super) user_session_id: Option<Uuid>,
        pub(super) user_id: Option<Uuid>,
        /// Set for OAuth 2.0 sessions, NULL for compat sessions.
        pub(super) scope_list: Option<Vec<String>>,
        /// Set for compat sessions, NULL for OAuth 2.0 sessions.
        pub(super) device_id: Option<String>,
        pub(super) created_at: DateTime<Utc>,
        pub(super) finished_at: Option<DateTime<Utc>>,
        /// Set for compat sessions, NULL for OAuth 2.0 sessions.
        pub(super) is_synapse_admin: Option<bool>,
        pub(super) user_agent: Option<String>,
        pub(super) last_active_at: Option<DateTime<Utc>>,
        pub(super) last_active_ip: Option<IpAddr>,
    }
}
use priv_::{AppSessionLookup, AppSessionLookupIden};
impl TryFrom<AppSessionLookup> for AppSession {
    type Error = DatabaseError;

    /// Convert one row of the `sessions` union back into a typed session.
    ///
    /// The populated columns decide the variant:
    /// - Compat sessions have `compat_session_id`, `user_id`, `device_id` and
    ///   `is_synapse_admin` set, and no OAuth 2.0 columns.
    /// - OAuth 2.0 sessions have `oauth2_session_id`, `oauth2_client_id` and
    ///   `scope_list` set, and no compat-only columns.
    ///
    /// # Errors
    ///
    /// Returns a [`DatabaseInconsistencyError`] (converted into
    /// [`DatabaseError`]) in three cases:
    /// - the populated columns match neither shape;
    /// - a compat session's device ID is invalid;
    /// - a stored OAuth 2.0 scope token fails to parse.
    // One match covers both session shapes; splitting it per variant would
    // scatter the row-shape rules, so the length lint is silenced.
    #[allow(clippy::too_many_lines)]
    fn try_from(value: AppSessionLookup) -> Result<Self, Self::Error> {
        let AppSessionLookup {
            cursor,
            compat_session_id,
            oauth2_session_id,
            oauth2_client_id,
            user_session_id,
            user_id,
            scope_list,
            device_id,
            created_at,
            finished_at,
            is_synapse_admin,
            user_agent,
            last_active_at,
            last_active_ip,
        } = value;

        // Fields shared by both session kinds.
        let user_agent = user_agent.map(UserAgent::parse);
        let user_session_id = user_session_id.map(Ulid::from);

        // Which columns are NULL tells us which half of the union the row
        // came from.
        match (
            compat_session_id,
            oauth2_session_id,
            oauth2_client_id,
            user_id,
            scope_list,
            device_id,
            is_synapse_admin,
        ) {
            // Compat session: no OAuth 2.0 columns, and a user, device and admin
            // flag are always present.
            (
                Some(compat_session_id),
                None,
                None,
                Some(user_id),
                None,
                Some(device_id),
                Some(is_synapse_admin),
            ) => {
                let id = compat_session_id.into();
                let device = Device::try_from(device_id).map_err(|e| {
                    DatabaseInconsistencyError::on("compat_sessions")
                        .column("device_id")
                        .row(id)
                        .source(e)
                })?;

                let state = match finished_at {
                    None => CompatSessionState::Valid,
                    Some(finished_at) => CompatSessionState::Finished { finished_at },
                };

                let session = CompatSession {
                    id,
                    state,
                    user_id: user_id.into(),
                    device,
                    user_session_id,
                    created_at,
                    is_synapse_admin,
                    user_agent,
                    last_active_at,
                    last_active_ip,
                };

                Ok(AppSession::Compat(Box::new(session)))
            }

            // OAuth 2.0 session: no compat-only columns. The user is optional.
            (
                None,
                Some(oauth2_session_id),
                Some(oauth2_client_id),
                user_id,
                Some(scope_list),
                None,
                None,
            ) => {
                let id = oauth2_session_id.into();
                let scope: Result<Scope, _> =
                    scope_list.iter().map(|s| s.parse::<ScopeToken>()).collect();
                let scope = scope.map_err(|e| {
                    DatabaseInconsistencyError::on("oauth2_sessions")
                        .column("scope")
                        .row(id)
                        .source(e)
                })?;

                // Use the destructured binding, like the compat arm does,
                // rather than reaching back into the partially-moved `value`.
                let state = match finished_at {
                    None => SessionState::Valid,
                    Some(finished_at) => SessionState::Finished { finished_at },
                };

                let session = Session {
                    id,
                    state,
                    created_at,
                    client_id: oauth2_client_id.into(),
                    user_id: user_id.map(Ulid::from),
                    user_session_id,
                    scope,
                    user_agent,
                    last_active_at,
                    last_active_ip,
                };

                Ok(AppSession::OAuth2(Box::new(session)))
            }

            // Any other combination means the row is inconsistent.
            _ => Err(DatabaseInconsistencyError::on("sessions")
                .row(cursor.into())
                .into()),
        }
    }
}
/// Project an [`AppSessionFilter`] onto the two concrete session kinds.
///
/// Every criterion present on `filter` (user, state, device, browser session,
/// last-activity bounds) is forwarded to both the compat filter and the
/// OAuth 2.0 filter. Criteria that are absent stay absent on both. Each half
/// of the `sessions` union can then be filtered against its own table.
fn split_filter(
    filter: AppSessionFilter<'_>,
) -> (CompatSessionFilter<'_>, OAuth2SessionFilter<'_>) {
    let compat = CompatSessionFilter::new();
    let oauth2 = OAuth2SessionFilter::new();

    let (compat, oauth2) = match filter.user() {
        Some(user) => (compat.for_user(user), oauth2.for_user(user)),
        None => (compat, oauth2),
    };

    let (compat, oauth2) = match filter.state() {
        Some(AppSessionState::Active) => (compat.active_only(), oauth2.active_only()),
        Some(AppSessionState::Finished) => (compat.finished_only(), oauth2.finished_only()),
        None => (compat, oauth2),
    };

    let (compat, oauth2) = match filter.device() {
        Some(device) => (compat.for_device(device), oauth2.for_device(device)),
        None => (compat, oauth2),
    };

    let (compat, oauth2) = match filter.browser_session() {
        Some(browser_session) => (
            compat.for_browser_session(browser_session),
            oauth2.for_browser_session(browser_session),
        ),
        None => (compat, oauth2),
    };

    let (compat, oauth2) = match filter.last_active_before() {
        Some(before) => (
            compat.with_last_active_before(before),
            oauth2.with_last_active_before(before),
        ),
        None => (compat, oauth2),
    };

    match filter.last_active_after() {
        Some(after) => (
            compat.with_last_active_after(after),
            oauth2.with_last_active_after(after),
        ),
        None => (compat, oauth2),
    }
}
#[async_trait]
impl<'c> AppSessionRepository for PgAppSessionRepository<'c> {
    type Error = DatabaseError;

    /// List compat and OAuth 2.0 sessions matching `filter` as one paginated
    /// list of [`AppSession`]s.
    ///
    /// How the query is built:
    /// 1. Both session tables are projected onto the columns of
    ///    `AppSessionLookup`.
    /// 2. The two projections are combined with `UNION ALL` into a `sessions`
    ///    common table expression.
    /// 3. The result is paginated on the `cursor` column (the session ID).
    ///
    /// Each returned row is then converted with `TryFrom`.
    ///
    /// # Errors
    ///
    /// Returns [`DatabaseError`] if the query fails or if a row cannot be
    /// converted into an [`AppSession`].
    // Both halves of the union spell out every column explicitly, which is
    // what makes this method long.
    #[allow(clippy::too_many_lines)]
    #[tracing::instrument(
        name = "db.app_session.list",
        fields(
            db.query.text,
        ),
        skip_all,
        err,
    )]
    async fn list(
        &mut self,
        filter: AppSessionFilter<'_>,
        pagination: Pagination,
    ) -> Result<Page<AppSession>, Self::Error> {
        // The same criteria are applied to each underlying table.
        let (compat_filter, oauth2_filter) = split_filter(filter);

        // UNION ALL matches columns by position. This SELECT and the compat
        // SELECT below must therefore project the same columns in the same
        // order. Compat-only columns are NULL placeholders here.
        let mut oauth2_session_select = Query::select()
            .expr_as(
                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2SessionId)),
                AppSessionLookupIden::Cursor,
            )
            .expr_as(Expr::cust("NULL"), AppSessionLookupIden::CompatSessionId)
            .expr_as(
                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2SessionId)),
                AppSessionLookupIden::Oauth2SessionId,
            )
            .expr_as(
                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2ClientId)),
                AppSessionLookupIden::Oauth2ClientId,
            )
            .expr_as(
                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserSessionId)),
                AppSessionLookupIden::UserSessionId,
            )
            .expr_as(
                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserId)),
                AppSessionLookupIden::UserId,
            )
            .expr_as(
                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::ScopeList)),
                AppSessionLookupIden::ScopeList,
            )
            .expr_as(Expr::cust("NULL"), AppSessionLookupIden::DeviceId)
            .expr_as(
                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::CreatedAt)),
                AppSessionLookupIden::CreatedAt,
            )
            .expr_as(
                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)),
                AppSessionLookupIden::FinishedAt,
            )
            .expr_as(Expr::cust("NULL"), AppSessionLookupIden::IsSynapseAdmin)
            .expr_as(
                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserAgent)),
                AppSessionLookupIden::UserAgent,
            )
            .expr_as(
                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::LastActiveAt)),
                AppSessionLookupIden::LastActiveAt,
            )
            .expr_as(
                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::LastActiveIp)),
                AppSessionLookupIden::LastActiveIp,
            )
            .from(OAuth2Sessions::Table)
            .apply_filter(oauth2_filter)
            .clone();

        // Same column order as above. OAuth 2.0-only columns are NULL
        // placeholders here.
        let compat_session_select = Query::select()
            .expr_as(
                Expr::col((CompatSessions::Table, CompatSessions::CompatSessionId)),
                AppSessionLookupIden::Cursor,
            )
            .expr_as(
                Expr::col((CompatSessions::Table, CompatSessions::CompatSessionId)),
                AppSessionLookupIden::CompatSessionId,
            )
            .expr_as(Expr::cust("NULL"), AppSessionLookupIden::Oauth2SessionId)
            .expr_as(Expr::cust("NULL"), AppSessionLookupIden::Oauth2ClientId)
            .expr_as(
                Expr::col((CompatSessions::Table, CompatSessions::UserSessionId)),
                AppSessionLookupIden::UserSessionId,
            )
            .expr_as(
                Expr::col((CompatSessions::Table, CompatSessions::UserId)),
                AppSessionLookupIden::UserId,
            )
            .expr_as(Expr::cust("NULL"), AppSessionLookupIden::ScopeList)
            .expr_as(
                Expr::col((CompatSessions::Table, CompatSessions::DeviceId)),
                AppSessionLookupIden::DeviceId,
            )
            .expr_as(
                Expr::col((CompatSessions::Table, CompatSessions::CreatedAt)),
                AppSessionLookupIden::CreatedAt,
            )
            .expr_as(
                Expr::col((CompatSessions::Table, CompatSessions::FinishedAt)),
                AppSessionLookupIden::FinishedAt,
            )
            .expr_as(
                Expr::col((CompatSessions::Table, CompatSessions::IsSynapseAdmin)),
                AppSessionLookupIden::IsSynapseAdmin,
            )
            .expr_as(
                Expr::col((CompatSessions::Table, CompatSessions::UserAgent)),
                AppSessionLookupIden::UserAgent,
            )
            .expr_as(
                Expr::col((CompatSessions::Table, CompatSessions::LastActiveAt)),
                AppSessionLookupIden::LastActiveAt,
            )
            .expr_as(
                Expr::col((CompatSessions::Table, CompatSessions::LastActiveIp)),
                AppSessionLookupIden::LastActiveIp,
            )
            .from(CompatSessions::Table)
            .apply_filter(compat_filter)
            .clone();

        // WITH sessions AS (<oauth2 select> UNION ALL <compat select>)
        let common_table_expression = CommonTableExpression::new()
            .query(
                oauth2_session_select
                    .union(UnionType::All, compat_session_select)
                    .clone(),
            )
            .table_name(Alias::new("sessions"))
            .clone();

        let with_clause = Query::with().cte(common_table_expression).clone();

        // Pagination orders on, and seeks by, the session ID exposed as
        // `cursor`.
        let select = Query::select()
            .column(ColumnRef::Asterisk)
            .from(Alias::new("sessions"))
            .generate_pagination(AppSessionLookupIden::Cursor, pagination)
            .clone();

        let (sql, arguments) = with_clause.query(select).build_sqlx(PostgresQueryBuilder);

        let edges: Vec<AppSessionLookup> = sqlx::query_as_with(&sql, arguments)
            .traced()
            .fetch_all(&mut *self.conn)
            .await?;

        // Each row becomes an AppSession; the first inconsistent row aborts
        // the whole listing.
        let page = pagination.process(edges).try_map(TryFrom::try_from)?;

        Ok(page)
    }

    /// Count the compat and OAuth 2.0 sessions matching `filter`.
    ///
    /// Uses the same per-table filters as `list`. Each branch of the union
    /// only selects a constant, because just the number of rows matters.
    ///
    /// # Errors
    ///
    /// Returns [`DatabaseError`] if the query fails or if the `i64` count
    /// cannot be converted into a `usize`.
    #[tracing::instrument(
        name = "db.app_session.count",
        fields(
            db.query.text,
        ),
        skip_all,
        err,
    )]
    async fn count(&mut self, filter: AppSessionFilter<'_>) -> Result<usize, Self::Error> {
        let (compat_filter, oauth2_filter) = split_filter(filter);

        let mut oauth2_session_select = Query::select()
            .expr(Expr::cust("1"))
            .from(OAuth2Sessions::Table)
            .apply_filter(oauth2_filter)
            .clone();

        let compat_session_select = Query::select()
            .expr(Expr::cust("1"))
            .from(CompatSessions::Table)
            .apply_filter(compat_filter)
            .clone();

        // WITH sessions AS (<oauth2 select> UNION ALL <compat select>)
        let common_table_expression = CommonTableExpression::new()
            .query(
                oauth2_session_select
                    .union(UnionType::All, compat_session_select)
                    .clone(),
            )
            .table_name(Alias::new("sessions"))
            .clone();

        let with_clause = Query::with().cte(common_table_expression).clone();

        let select = Query::select()
            .expr(Expr::cust("COUNT(*)"))
            .from(Alias::new("sessions"))
            .clone();

        let (sql, arguments) = with_clause.query(select).build_sqlx(PostgresQueryBuilder);

        let count: i64 = sqlx::query_scalar_with(&sql, arguments)
            .traced()
            .fetch_one(&mut *self.conn)
            .await?;

        // COUNT(*) is read as an i64; a value that does not fit in usize is
        // reported as an invalid-operation error instead of being truncated.
        count
            .try_into()
            .map_err(DatabaseError::to_invalid_operation)
    }
}
#[cfg(test)]
mod tests {
    use chrono::Duration;
    use mas_data_model::Device;
    use mas_storage::{
        app_session::{AppSession, AppSessionFilter},
        clock::MockClock,
        oauth2::OAuth2SessionRepository,
        Pagination, RepositoryAccess,
    };
    use oauth2_types::{
        requests::GrantType,
        scope::{Scope, OPENID},
    };
    use rand::SeedableRng;
    use rand_chacha::ChaChaRng;
    use sqlx::PgPool;

    use crate::PgRepository;

    /// End-to-end check of `count` and `list` over a mix of sessions.
    ///
    /// Covers one compat session and one OAuth 2.0 session as each moves from
    /// active to finished. Also checks the per-device and per-user filters.
    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_app_repo(pool: PgPool) {
        let mut rng = ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();
        let mut repo = PgRepository::from_pool(&pool).await.unwrap();

        // Create a user
        let user = repo
            .user()
            .add(&mut rng, &clock, "john".to_owned())
            .await
            .unwrap();

        let all = AppSessionFilter::new().for_user(&user);
        let active = all.active_only();
        let finished = all.finished_only();
        let pagination = Pagination::first(10);

        // No sessions yet
        assert_eq!(repo.app_session().count(all).await.unwrap(), 0);
        assert_eq!(repo.app_session().count(active).await.unwrap(), 0);
        assert_eq!(repo.app_session().count(finished).await.unwrap(), 0);

        let full_list = repo.app_session().list(all, pagination).await.unwrap();
        assert!(full_list.edges.is_empty());
        let active_list = repo.app_session().list(active, pagination).await.unwrap();
        assert!(active_list.edges.is_empty());
        let finished_list = repo.app_session().list(finished, pagination).await.unwrap();
        assert!(finished_list.edges.is_empty());

        // Start a compat session for that user
        let device = Device::generate(&mut rng);
        let compat_session = repo
            .compat_session()
            .add(&mut rng, &clock, &user, device.clone(), None, false)
            .await
            .unwrap();

        assert_eq!(repo.app_session().count(all).await.unwrap(), 1);
        assert_eq!(repo.app_session().count(active).await.unwrap(), 1);
        assert_eq!(repo.app_session().count(finished).await.unwrap(), 0);

        let full_list = repo.app_session().list(all, pagination).await.unwrap();
        assert_eq!(full_list.edges.len(), 1);
        assert_eq!(
            full_list.edges[0],
            AppSession::Compat(Box::new(compat_session.clone()))
        );

        let active_list = repo.app_session().list(active, pagination).await.unwrap();
        assert_eq!(active_list.edges.len(), 1);
        assert_eq!(
            active_list.edges[0],
            AppSession::Compat(Box::new(compat_session.clone()))
        );

        let finished_list = repo.app_session().list(finished, pagination).await.unwrap();
        assert!(finished_list.edges.is_empty());

        // Finish the compat session
        let compat_session = repo
            .compat_session()
            .finish(&clock, compat_session)
            .await
            .unwrap();

        assert_eq!(repo.app_session().count(all).await.unwrap(), 1);
        assert_eq!(repo.app_session().count(active).await.unwrap(), 0);
        assert_eq!(repo.app_session().count(finished).await.unwrap(), 1);

        let full_list = repo.app_session().list(all, pagination).await.unwrap();
        assert_eq!(full_list.edges.len(), 1);
        assert_eq!(
            full_list.edges[0],
            AppSession::Compat(Box::new(compat_session.clone()))
        );

        let active_list = repo.app_session().list(active, pagination).await.unwrap();
        assert!(active_list.edges.is_empty());

        let finished_list = repo.app_session().list(finished, pagination).await.unwrap();
        assert_eq!(finished_list.edges.len(), 1);
        assert_eq!(
            finished_list.edges[0],
            AppSession::Compat(Box::new(compat_session.clone()))
        );

        // Start an OAuth2 session
        let client = repo
            .oauth2_client()
            .add(
                &mut rng,
                &clock,
                vec!["https://example.com/redirect".parse().unwrap()],
                None,
                None,
                vec![GrantType::AuthorizationCode],
                Some("First client".to_owned()),
                Some("https://example.com/logo.png".parse().unwrap()),
                Some("https://example.com/".parse().unwrap()),
                Some("https://example.com/policy".parse().unwrap()),
                Some("https://example.com/tos".parse().unwrap()),
                Some("https://example.com/jwks.json".parse().unwrap()),
                None,
                None,
                None,
                None,
                None,
                Some("https://example.com/login".parse().unwrap()),
            )
            .await
            .unwrap();

        let device2 = Device::generate(&mut rng);
        let scope = Scope::from_iter([OPENID, device2.to_scope_token()]);

        // Advance the clock so the OAuth2 session sorts after the compat one.
        clock.advance(Duration::try_minutes(1).unwrap());

        let oauth_session = repo
            .oauth2_session()
            .add(&mut rng, &clock, &client, Some(&user), None, scope)
            .await
            .unwrap();

        assert_eq!(repo.app_session().count(all).await.unwrap(), 2);
        assert_eq!(repo.app_session().count(active).await.unwrap(), 1);
        assert_eq!(repo.app_session().count(finished).await.unwrap(), 1);

        let full_list = repo.app_session().list(all, pagination).await.unwrap();
        assert_eq!(full_list.edges.len(), 2);
        assert_eq!(
            full_list.edges[0],
            AppSession::Compat(Box::new(compat_session.clone()))
        );
        assert_eq!(
            full_list.edges[1],
            AppSession::OAuth2(Box::new(oauth_session.clone()))
        );

        let active_list = repo.app_session().list(active, pagination).await.unwrap();
        assert_eq!(active_list.edges.len(), 1);
        assert_eq!(
            active_list.edges[0],
            AppSession::OAuth2(Box::new(oauth_session.clone()))
        );

        let finished_list = repo.app_session().list(finished, pagination).await.unwrap();
        assert_eq!(finished_list.edges.len(), 1);
        assert_eq!(
            finished_list.edges[0],
            AppSession::Compat(Box::new(compat_session.clone()))
        );

        // Finish the OAuth2 session
        let oauth_session = repo
            .oauth2_session()
            .finish(&clock, oauth_session)
            .await
            .unwrap();

        assert_eq!(repo.app_session().count(all).await.unwrap(), 2);
        assert_eq!(repo.app_session().count(active).await.unwrap(), 0);
        assert_eq!(repo.app_session().count(finished).await.unwrap(), 2);

        let full_list = repo.app_session().list(all, pagination).await.unwrap();
        assert_eq!(full_list.edges.len(), 2);
        assert_eq!(
            full_list.edges[0],
            AppSession::Compat(Box::new(compat_session.clone()))
        );
        assert_eq!(
            full_list.edges[1],
            AppSession::OAuth2(Box::new(oauth_session.clone()))
        );

        let active_list = repo.app_session().list(active, pagination).await.unwrap();
        assert!(active_list.edges.is_empty());

        let finished_list = repo.app_session().list(finished, pagination).await.unwrap();
        assert_eq!(finished_list.edges.len(), 2);
        assert_eq!(
            finished_list.edges[0],
            AppSession::Compat(Box::new(compat_session.clone()))
        );
        // Check the finished list itself, not the full list, so the second
        // finished edge is actually verified.
        assert_eq!(
            finished_list.edges[1],
            AppSession::OAuth2(Box::new(oauth_session.clone()))
        );

        // Query by device
        let filter = AppSessionFilter::new().for_device(&device);
        assert_eq!(repo.app_session().count(filter).await.unwrap(), 1);
        let list = repo.app_session().list(filter, pagination).await.unwrap();
        assert_eq!(list.edges.len(), 1);
        assert_eq!(
            list.edges[0],
            AppSession::Compat(Box::new(compat_session.clone()))
        );

        let filter = AppSessionFilter::new().for_device(&device2);
        assert_eq!(repo.app_session().count(filter).await.unwrap(), 1);
        let list = repo.app_session().list(filter, pagination).await.unwrap();
        assert_eq!(list.edges.len(), 1);
        assert_eq!(
            list.edges[0],
            AppSession::OAuth2(Box::new(oauth_session.clone()))
        );

        // A different user must not see any of the sessions above
        let user2 = repo
            .user()
            .add(&mut rng, &clock, "alice".to_owned())
            .await
            .unwrap();

        let filter = AppSessionFilter::new().for_user(&user2);
        assert_eq!(repo.app_session().count(filter).await.unwrap(), 0);
        let list = repo.app_session().list(filter, pagination).await.unwrap();
        assert!(list.edges.is_empty());
    }
}