use async_trait::async_trait;
use chrono::{DateTime, Utc};
use mas_data_model::{UpstreamOAuthLink, UpstreamOAuthProvider, User};
use mas_storage::{
upstream_oauth2::{UpstreamOAuthLinkFilter, UpstreamOAuthLinkRepository},
Clock, Page, Pagination,
};
use rand::RngCore;
use sea_query::{enum_def, Expr, PostgresQueryBuilder, Query};
use sea_query_binder::SqlxBinder;
use sqlx::PgConnection;
use ulid::Ulid;
use uuid::Uuid;
use crate::{
filter::{Filter, StatementExt},
iden::{UpstreamOAuthLinks, UpstreamOAuthProviders},
pagination::QueryBuilderExt,
tracing::ExecuteExt,
DatabaseError,
};
/// An implementation of [`UpstreamOAuthLinkRepository`] backed by a
/// PostgreSQL connection.
///
/// The repository does not own the connection: it holds an exclusive borrow
/// for the lifetime `'c`, and every query it issues runs on that connection.
pub struct PgUpstreamOAuthLinkRepository<'c> {
// Exclusive borrow of the connection the queries are executed on.
conn: &'c mut PgConnection,
}
impl<'c> PgUpstreamOAuthLinkRepository<'c> {
pub fn new(conn: &'c mut PgConnection) -> Self {
Self { conn }
}
}
/// Row shape shared by the link queries in this file.
///
/// It is the target type of the `sqlx::query_as!` calls and of the
/// dynamically built query in `list`. The `#[enum_def]` attribute provides
/// the `LinkLookupIden` identifiers which `list` uses as column aliases.
#[derive(sqlx::FromRow)]
#[enum_def]
struct LinkLookup {
// Primary identifier of the link, converted to a `Ulid` in the data model.
upstream_oauth_link_id: Uuid,
// Identifier of the provider the link belongs to.
upstream_oauth_provider_id: Uuid,
// `None` while the link is not associated with a user (`add` inserts NULL).
user_id: Option<Uuid>,
// Subject of the link; looked up together with the provider id in
// `find_by_subject`.
subject: String,
// Timestamp stored when the link was inserted.
created_at: DateTime<Utc>,
}
impl From<LinkLookup> for UpstreamOAuthLink {
fn from(value: LinkLookup) -> Self {
UpstreamOAuthLink {
id: Ulid::from(value.upstream_oauth_link_id),
provider_id: Ulid::from(value.upstream_oauth_provider_id),
user_id: value.user_id.map(Ulid::from),
subject: value.subject,
created_at: value.created_at,
}
}
}
impl Filter for UpstreamOAuthLinkFilter<'_> {
    /// Build the condition matching this filter.
    ///
    /// Each of the three constraints (user, provider, provider enabled) is
    /// optional. Constraints that are not set on the filter are skipped; the
    /// ones that are set are added to a single `Condition::all()`.
    fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
        // Only links attached to the given user.
        let by_user = self.user().map(|user| {
            Expr::col((UpstreamOAuthLinks::Table, UpstreamOAuthLinks::UserId))
                .eq(Uuid::from(user.id))
        });

        // Only links that belong to the given provider.
        let by_provider = self.provider().map(|provider| {
            Expr::col((
                UpstreamOAuthLinks::Table,
                UpstreamOAuthLinks::UpstreamOAuthProviderId,
            ))
            .eq(Uuid::from(provider.id))
        });

        // Only links whose provider is enabled (or disabled): the provider id
        // of the link must be one of the ids returned by a subquery selecting
        // the providers for which "`DisabledAt` is null" equals `enabled`.
        let by_provider_enabled = self.provider_enabled().map(|enabled| {
            let matching_providers = Query::select()
                .expr(Expr::col((
                    UpstreamOAuthProviders::Table,
                    UpstreamOAuthProviders::UpstreamOAuthProviderId,
                )))
                .from(UpstreamOAuthProviders::Table)
                .and_where(
                    Expr::col((
                        UpstreamOAuthProviders::Table,
                        UpstreamOAuthProviders::DisabledAt,
                    ))
                    .is_null()
                    .eq(enabled),
                )
                .take();

            Expr::col((
                UpstreamOAuthLinks::Table,
                UpstreamOAuthLinks::UpstreamOAuthProviderId,
            ))
            .eq(Expr::any(matching_providers))
        });

        sea_query::Condition::all()
            .add_option(by_user)
            .add_option(by_provider)
            .add_option(by_provider_enabled)
    }
}
#[async_trait]
impl<'c> UpstreamOAuthLinkRepository for PgUpstreamOAuthLinkRepository<'c> {
type Error = DatabaseError;
#[tracing::instrument(
name = "db.upstream_oauth_link.lookup",
skip_all,
fields(
db.query.text,
upstream_oauth_link.id = %id,
),
err,
)]
async fn lookup(&mut self, id: Ulid) -> Result<Option<UpstreamOAuthLink>, Self::Error> {
let res = sqlx::query_as!(
LinkLookup,
r#"
SELECT
upstream_oauth_link_id,
upstream_oauth_provider_id,
user_id,
subject,
created_at
FROM upstream_oauth_links
WHERE upstream_oauth_link_id = $1
"#,
Uuid::from(id),
)
.traced()
.fetch_optional(&mut *self.conn)
.await?
.map(Into::into);
Ok(res)
}
#[tracing::instrument(
name = "db.upstream_oauth_link.find_by_subject",
skip_all,
fields(
db.query.text,
upstream_oauth_link.subject = subject,
%upstream_oauth_provider.id,
%upstream_oauth_provider.issuer,
%upstream_oauth_provider.client_id,
),
err,
)]
async fn find_by_subject(
&mut self,
upstream_oauth_provider: &UpstreamOAuthProvider,
subject: &str,
) -> Result<Option<UpstreamOAuthLink>, Self::Error> {
let res = sqlx::query_as!(
LinkLookup,
r#"
SELECT
upstream_oauth_link_id,
upstream_oauth_provider_id,
user_id,
subject,
created_at
FROM upstream_oauth_links
WHERE upstream_oauth_provider_id = $1
AND subject = $2
"#,
Uuid::from(upstream_oauth_provider.id),
subject,
)
.traced()
.fetch_optional(&mut *self.conn)
.await?
.map(Into::into);
Ok(res)
}
/// Insert a new upstream OAuth link for the given provider and subject.
///
/// The link ID is a ULID generated from the clock's current time and the
/// supplied random source; it is recorded on the current tracing span
/// before the insert runs. The new link is stored without a user
/// (`user_id` is `NULL`) and returned with `user_id: None`.
///
/// # Errors
///
/// Returns `Self::Error` if the database query fails.
#[tracing::instrument(
    name = "db.upstream_oauth_link.add",
    skip_all,
    fields(
        db.query.text,
        upstream_oauth_link.id,
        upstream_oauth_link.subject = subject,
        %upstream_oauth_provider.id,
        %upstream_oauth_provider.issuer,
        %upstream_oauth_provider.client_id,
    ),
    err,
)]
async fn add(
    &mut self,
    rng: &mut (dyn RngCore + Send),
    clock: &dyn Clock,
    upstream_oauth_provider: &UpstreamOAuthProvider,
    subject: String,
) -> Result<UpstreamOAuthLink, Self::Error> {
    let now = clock.now();
    let link_id = Ulid::from_datetime_with_source(now.into(), rng);

    // The span field was declared empty in the attribute above; fill it in
    // now that the ID is known.
    tracing::Span::current().record(
        "upstream_oauth_link.id",
        tracing::field::display(link_id),
    );

    // The SQL text below is deliberately left flush-left and untouched.
    sqlx::query!(
        r#"
INSERT INTO upstream_oauth_links (
upstream_oauth_link_id,
upstream_oauth_provider_id,
user_id,
subject,
created_at
) VALUES ($1, $2, NULL, $3, $4)
"#,
        Uuid::from(link_id),
        Uuid::from(upstream_oauth_provider.id),
        &subject,
        now,
    )
    .traced()
    .execute(&mut *self.conn)
    .await?;

    let link = UpstreamOAuthLink {
        id: link_id,
        provider_id: upstream_oauth_provider.id,
        user_id: None,
        subject,
        created_at: now,
    };

    Ok(link)
}
#[tracing::instrument(
name = "db.upstream_oauth_link.associate_to_user",
skip_all,
fields(
db.query.text,
%upstream_oauth_link.id,
%upstream_oauth_link.subject,
%user.id,
%user.username,
),
err,
)]
async fn associate_to_user(
&mut self,
upstream_oauth_link: &UpstreamOAuthLink,
user: &User,
) -> Result<(), Self::Error> {
sqlx::query!(
r#"
UPDATE upstream_oauth_links
SET user_id = $1
WHERE upstream_oauth_link_id = $2
"#,
Uuid::from(user.id),
Uuid::from(upstream_oauth_link.id),
)
.traced()
.execute(&mut *self.conn)
.await?;
Ok(())
}
#[tracing::instrument(
name = "db.upstream_oauth_link.list",
skip_all,
fields(
db.query.text,
),
err,
)]
async fn list(
&mut self,
filter: UpstreamOAuthLinkFilter<'_>,
pagination: Pagination,
) -> Result<Page<UpstreamOAuthLink>, DatabaseError> {
let (sql, arguments) = Query::select()
.expr_as(
Expr::col((
UpstreamOAuthLinks::Table,
UpstreamOAuthLinks::UpstreamOAuthLinkId,
)),
LinkLookupIden::UpstreamOauthLinkId,
)
.expr_as(
Expr::col((
UpstreamOAuthLinks::Table,
UpstreamOAuthLinks::UpstreamOAuthProviderId,
)),
LinkLookupIden::UpstreamOauthProviderId,
)
.expr_as(
Expr::col((UpstreamOAuthLinks::Table, UpstreamOAuthLinks::UserId)),
LinkLookupIden::UserId,
)
.expr_as(
Expr::col((UpstreamOAuthLinks::Table, UpstreamOAuthLinks::Subject)),
LinkLookupIden::Subject,
)
.expr_as(
Expr::col((UpstreamOAuthLinks::Table, UpstreamOAuthLinks::CreatedAt)),
LinkLookupIden::CreatedAt,
)
.from(UpstreamOAuthLinks::Table)
.apply_filter(filter)
.generate_pagination(
(
UpstreamOAuthLinks::Table,
UpstreamOAuthLinks::UpstreamOAuthLinkId,
),
pagination,
)
.build_sqlx(PostgresQueryBuilder);
let edges: Vec<LinkLookup> = sqlx::query_as_with(&sql, arguments)
.traced()
.fetch_all(&mut *self.conn)
.await?;
let page = pagination.process(edges).map(UpstreamOAuthLink::from);
Ok(page)
}
/// Count the upstream OAuth links matching `filter`.
///
/// # Errors
///
/// Returns `Self::Error` if the database query fails, or if the count
/// returned by the database does not fit in a `usize`.
#[tracing::instrument(
    name = "db.upstream_oauth_link.count",
    skip_all,
    fields(
        db.query.text,
    ),
    err,
)]
async fn count(&mut self, filter: UpstreamOAuthLinkFilter<'_>) -> Result<usize, Self::Error> {
    let link_id_count = Expr::col((
        UpstreamOAuthLinks::Table,
        UpstreamOAuthLinks::UpstreamOAuthLinkId,
    ))
    .count();

    let (sql, arguments) = Query::select()
        .expr(link_id_count)
        .from(UpstreamOAuthLinks::Table)
        .apply_filter(filter)
        .build_sqlx(PostgresQueryBuilder);

    // The database hands the count back as a signed 64-bit integer.
    let total: i64 = sqlx::query_scalar_with(&sql, arguments)
        .traced()
        .fetch_one(&mut *self.conn)
        .await?;

    usize::try_from(total).map_err(DatabaseError::to_invalid_operation)
}
}