use anyhow::Context as _;
use async_graphql::{Context, Description, Enum, ID, InputObject, Object};
use mas_i18n::DataLocale;
use mas_storage::{
RepositoryAccess,
queue::{ProvisionUserJob, QueueJobRepositoryExt as _, SendEmailAuthenticationCodeJob},
user::{UserEmailFilter, UserEmailRepository, UserRepository},
};
use super::verify_password_if_needed;
use crate::graphql::{
model::{NodeType, User, UserEmail, UserEmailAuthentication},
state::ContextExt,
};
/// Mutations related to the email addresses of users.
#[derive(Default)]
pub struct UserEmailMutations {
    // Private field: outside this module the struct can only be obtained via `Default`
    _private: (),
}
/// The input for the `addEmail` mutation
#[derive(InputObject)]
struct AddEmailInput {
    /// The email address to add
    email: String,
    /// The ID of the user to add the email address to
    user_id: ID,
    /// Accepted for input, but currently ignored by `addEmail`
    skip_verification: Option<bool>,
    /// When true, the email policy is not evaluated (defaults to false)
    skip_policy_check: Option<bool>,
}
/// The status of the `addEmail` mutation
#[derive(Enum, Copy, Clone, Eq, PartialEq)]
pub enum AddEmailStatus {
    /// The email address was added to the user
    Added,
    /// The user already had this email address
    Exists,
    /// The email address could not be parsed as a valid address
    Invalid,
    /// The email address was rejected by the policy
    Denied,
}
/// The payload of the `addEmail` mutation
#[derive(Description)]
enum AddEmailPayload {
    /// A new email address was created for the user
    Added(mas_data_model::UserEmail),
    /// The user already had this email address; carries the existing record
    Exists(mas_data_model::UserEmail),
    /// The email address could not be parsed
    Invalid,
    /// The policy rejected the email address
    Denied {
        /// The policy violations explaining the rejection
        violations: Vec<mas_policy::Violation>,
    },
}
#[Object(use_type_description)]
impl AddEmailPayload {
    /// Status of the operation
    async fn status(&self) -> AddEmailStatus {
        match self {
            AddEmailPayload::Added(_) => AddEmailStatus::Added,
            AddEmailPayload::Exists(_) => AddEmailStatus::Exists,
            AddEmailPayload::Invalid => AddEmailStatus::Invalid,
            AddEmailPayload::Denied { .. } => AddEmailStatus::Denied,
        }
    }

    /// The email address that was added, or that the user already had
    async fn email(&self) -> Option<UserEmail> {
        match self {
            AddEmailPayload::Added(email) | AddEmailPayload::Exists(email) => {
                Some(UserEmail(email.clone()))
            }
            AddEmailPayload::Invalid | AddEmailPayload::Denied { .. } => None,
        }
    }

    /// The user owning the email address, if one was added or already existed
    async fn user(&self, ctx: &Context<'_>) -> Result<Option<User>, async_graphql::Error> {
        let state = ctx.state();

        let user_id = match self {
            AddEmailPayload::Added(email) | AddEmailPayload::Exists(email) => email.user_id,
            AddEmailPayload::Invalid | AddEmailPayload::Denied { .. } => return Ok(None),
        };

        // Only acquire a repository once we know there is a user to look up, so
        // `Invalid`/`Denied` payloads never fail on repository acquisition
        let mut repo = state.repository().await?;

        let user = repo
            .user()
            .lookup(user_id)
            .await?
            .context("User not found")?;

        Ok(Some(User(user)))
    }

    /// The policy violation messages, if the email address was denied
    async fn violations(&self) -> Option<Vec<String>> {
        let AddEmailPayload::Denied { violations } = self else {
            return None;
        };

        let messages = violations.iter().map(|v| v.msg.clone()).collect();
        Some(messages)
    }
}
/// The input for the `removeEmail` mutation
#[derive(InputObject)]
struct RemoveEmailInput {
    /// The ID of the email address to remove
    user_email_id: ID,
    /// Password of the user owning the address, handed to
    /// `verify_password_if_needed` (presumably only required by some site
    /// configurations — confirm)
    password: Option<String>,
}
/// The status of the `removeEmail` mutation
#[derive(Enum, Copy, Clone, Eq, PartialEq)]
enum RemoveEmailStatus {
    /// The email address was removed
    Removed,
    /// The email address was not found, or the requester may not access it
    NotFound,
    /// The password check failed
    IncorrectPassword,
}
/// The payload of the `removeEmail` mutation
#[derive(Description)]
enum RemoveEmailPayload {
    /// The email address was removed; carries the removed record
    Removed(mas_data_model::UserEmail),
    /// The email address was not found, or the requester may not access it
    NotFound,
    /// The password check failed
    IncorrectPassword,
}
#[Object(use_type_description)]
impl RemoveEmailPayload {
    /// Status of the operation
    async fn status(&self) -> RemoveEmailStatus {
        match self {
            Self::NotFound => RemoveEmailStatus::NotFound,
            Self::IncorrectPassword => RemoveEmailStatus::IncorrectPassword,
            Self::Removed(_) => RemoveEmailStatus::Removed,
        }
    }

    /// The email address that was removed, if any
    async fn email(&self) -> Option<UserEmail> {
        let Self::Removed(removed) = self else {
            return None;
        };
        Some(UserEmail(removed.clone()))
    }

    /// The user who owned the removed email address, if any
    async fn user(&self, ctx: &Context<'_>) -> Result<Option<User>, async_graphql::Error> {
        let state = ctx.state();

        // Nothing was removed, so there is no owner to report
        let Self::Removed(removed) = self else {
            return Ok(None);
        };
        let owner_id = removed.user_id;

        let mut repo = state.repository().await?;
        let owner = repo
            .user()
            .lookup(owner_id)
            .await?
            .context("User not found")?;

        Ok(Some(User(owner)))
    }
}
/// The input for the deprecated `setPrimaryEmail` mutation
#[derive(InputObject)]
struct SetPrimaryEmailInput {
    /// The ID of the email address
    user_email_id: ID,
}
/// The status of the deprecated `setPrimaryEmail` mutation
#[derive(Enum, Copy, Clone, Eq, PartialEq)]
enum SetPrimaryEmailStatus {
    /// The mutation succeeded
    Set,
    /// The email address was not found
    NotFound,
    /// Never produced by `SetPrimaryEmailPayload::status`; presumably kept so
    /// the GraphQL enum stays backward compatible
    Unverified,
}
/// The payload of the deprecated `setPrimaryEmail` mutation
#[derive(Description)]
enum SetPrimaryEmailPayload {
    /// The mutation succeeded; carries the user owning the email address
    Set(mas_data_model::User),
    /// The email address was not found
    NotFound,
}
#[Object(use_type_description)]
impl SetPrimaryEmailPayload {
    /// Status of the operation
    async fn status(&self) -> SetPrimaryEmailStatus {
        match self {
            Self::NotFound => SetPrimaryEmailStatus::NotFound,
            Self::Set(_) => SetPrimaryEmailStatus::Set,
        }
    }

    /// The user owning the email address, when it was found
    async fn user(&self) -> Option<User> {
        let Self::Set(owner) = self else {
            return None;
        };
        Some(User(owner.clone()))
    }
}
/// The input for the `startEmailAuthentication` mutation
#[derive(InputObject)]
struct StartEmailAuthenticationInput {
    /// The email address to authenticate
    email: String,
    /// The requester's password, handed to `verify_password_if_needed`
    /// (presumably only required by some site configurations — confirm)
    password: Option<String>,
    /// Language passed to the code-sending job; must parse as a `DataLocale`.
    /// Defaults to "en"
    #[graphql(default = "en")]
    language: String,
}
/// The status of the `startEmailAuthentication` mutation
#[derive(Enum, Copy, Clone, Eq, PartialEq)]
enum StartEmailAuthenticationStatus {
    /// The authentication was created and sending of a code was scheduled
    Started,
    /// The email address could not be parsed
    InvalidEmailAddress,
    /// The rate limiter rejected the request
    RateLimited,
    /// The email address was rejected by the policy
    Denied,
    /// The requester's own user already has this email address
    InUse,
    /// The password check failed
    IncorrectPassword,
}
/// The payload of the `startEmailAuthentication` mutation
#[derive(Description)]
enum StartEmailAuthenticationPayload {
    /// The authentication was created
    Started(UserEmailAuthentication),
    /// The email address could not be parsed
    InvalidEmailAddress,
    /// The rate limiter rejected the request
    RateLimited,
    /// The policy rejected the email address
    Denied {
        /// The policy violations explaining the rejection
        violations: Vec<mas_policy::Violation>,
    },
    /// The requester's own user already has this email address
    InUse,
    /// The password check failed
    IncorrectPassword,
}
#[Object(use_type_description)]
impl StartEmailAuthenticationPayload {
async fn status(&self) -> StartEmailAuthenticationStatus {
match self {
Self::Started(_) => StartEmailAuthenticationStatus::Started,
Self::InvalidEmailAddress => StartEmailAuthenticationStatus::InvalidEmailAddress,
Self::RateLimited => StartEmailAuthenticationStatus::RateLimited,
Self::Denied { .. } => StartEmailAuthenticationStatus::Denied,
Self::InUse => StartEmailAuthenticationStatus::InUse,
Self::IncorrectPassword => StartEmailAuthenticationStatus::IncorrectPassword,
}
}
async fn authentication(&self) -> Option<&UserEmailAuthentication> {
match self {
Self::Started(authentication) => Some(authentication),
Self::InvalidEmailAddress
| Self::RateLimited
| Self::Denied { .. }
| Self::InUse
| Self::IncorrectPassword => None,
}
}
async fn violations(&self) -> Option<Vec<String>> {
let Self::Denied { violations } = self else {
return None;
};
let messages = violations.iter().map(|v| v.msg.clone()).collect();
Some(messages)
}
}
/// The input for the `completeEmailAuthentication` mutation
#[derive(InputObject)]
struct CompleteEmailAuthenticationInput {
    /// The authentication code submitted by the user
    code: String,
    /// The ID of the email authentication session
    id: ID,
}
/// The payload of the `completeEmailAuthentication` mutation
#[derive(Description)]
enum CompleteEmailAuthenticationPayload {
    /// The code was accepted and the email address added to the user
    Completed,
    /// Unknown authentication, one belonging to another session, or wrong code
    InvalidCode,
    /// The code matched but has expired
    CodeExpired,
    /// The email address is already attached to some user (any user)
    InUse,
    /// The rate limiter rejected the attempt
    RateLimited,
}
/// The status of the `completeEmailAuthentication` mutation
#[derive(Enum, Copy, Clone, Eq, PartialEq)]
enum CompleteEmailAuthenticationStatus {
    /// The code was accepted and the email address added to the user
    Completed,
    /// Unknown authentication, one belonging to another session, or wrong code
    InvalidCode,
    /// The code matched but has expired
    CodeExpired,
    /// The rate limiter rejected the attempt
    RateLimited,
    /// The email address is already attached to some user
    InUse,
}
#[Object(use_type_description)]
impl CompleteEmailAuthenticationPayload {
async fn status(&self) -> CompleteEmailAuthenticationStatus {
match self {
Self::Completed => CompleteEmailAuthenticationStatus::Completed,
Self::InvalidCode => CompleteEmailAuthenticationStatus::InvalidCode,
Self::CodeExpired => CompleteEmailAuthenticationStatus::CodeExpired,
Self::InUse => CompleteEmailAuthenticationStatus::InUse,
Self::RateLimited => CompleteEmailAuthenticationStatus::RateLimited,
}
}
}
/// The input for the `resendEmailAuthenticationCode` mutation
#[derive(InputObject)]
struct ResendEmailAuthenticationCodeInput {
    /// The ID of the email authentication session
    id: ID,
    /// Language passed to the code-sending job; must parse as a `DataLocale`.
    /// Defaults to "en"
    #[graphql(default = "en")]
    language: String,
}
/// The payload of the `resendEmailAuthenticationCode` mutation
#[derive(Description)]
enum ResendEmailAuthenticationCodePayload {
    /// Sending of a new code was scheduled
    Resent,
    /// The authentication is already completed, or could not be found
    Completed,
    /// The rate limiter rejected the request
    RateLimited,
}
/// The status of the `resendEmailAuthenticationCode` mutation
#[derive(Enum, Copy, Clone, Eq, PartialEq)]
enum ResendEmailAuthenticationCodeStatus {
    /// Sending of a new code was scheduled
    Resent,
    /// The authentication is already completed, or could not be found
    Completed,
    /// The rate limiter rejected the request
    RateLimited,
}
#[Object(use_type_description)]
impl ResendEmailAuthenticationCodePayload {
async fn status(&self) -> ResendEmailAuthenticationCodeStatus {
match self {
Self::Resent => ResendEmailAuthenticationCodeStatus::Resent,
Self::Completed => ResendEmailAuthenticationCodeStatus::Completed,
Self::RateLimited => ResendEmailAuthenticationCodeStatus::RateLimited,
}
}
}
#[Object]
impl UserEmailMutations {
#[graphql(deprecation = "Use `startEmailAuthentication` instead.")]
async fn add_email(
&self,
ctx: &Context<'_>,
input: AddEmailInput,
) -> Result<AddEmailPayload, async_graphql::Error> {
let state = ctx.state();
let id = NodeType::User.extract_ulid(&input.user_id)?;
let requester = ctx.requester();
let clock = state.clock();
let mut rng = state.rng();
if !requester.is_admin() {
return Err(async_graphql::Error::new("Unauthorized"));
}
let _skip_verification = input.skip_verification.unwrap_or(false);
let skip_policy_check = input.skip_policy_check.unwrap_or(false);
let mut repo = state.repository().await?;
let user = repo
.user()
.lookup(id)
.await?
.context("Failed to load user")?;
if input.email.parse::<lettre::Address>().is_err() {
return Ok(AddEmailPayload::Invalid);
}
if !skip_policy_check {
let mut policy = state.policy().await?;
let res = policy
.evaluate_email(mas_policy::EmailInput {
email: &input.email,
requester: requester.for_policy(),
})
.await?;
if !res.valid() {
return Ok(AddEmailPayload::Denied {
violations: res.violations,
});
}
}
let existing_user_email = repo.user_email().find(&user, &input.email).await?;
let (added, user_email) = if let Some(user_email) = existing_user_email {
(false, user_email)
} else {
let user_email = repo
.user_email()
.add(&mut rng, &clock, &user, input.email)
.await?;
(true, user_email)
};
repo.save().await?;
let payload = if added {
AddEmailPayload::Added(user_email)
} else {
AddEmailPayload::Exists(user_email)
};
Ok(payload)
}
async fn remove_email(
&self,
ctx: &Context<'_>,
input: RemoveEmailInput,
) -> Result<RemoveEmailPayload, async_graphql::Error> {
let state = ctx.state();
let user_email_id = NodeType::UserEmail.extract_ulid(&input.user_email_id)?;
let requester = ctx.requester();
let mut rng = state.rng();
let clock = state.clock();
let mut repo = state.repository().await?;
let user_email = repo.user_email().lookup(user_email_id).await?;
let Some(user_email) = user_email else {
return Ok(RemoveEmailPayload::NotFound);
};
if !requester.is_owner_or_admin(&user_email) {
return Ok(RemoveEmailPayload::NotFound);
}
if !requester.is_admin() && !state.site_config().email_change_allowed {
return Err(async_graphql::Error::new("Unauthorized"));
}
let user = repo
.user()
.lookup(user_email.user_id)
.await?
.context("Failed to load user")?;
if !verify_password_if_needed(
requester,
state.site_config(),
&state.password_manager(),
input.password,
&user,
&mut repo,
)
.await?
{
return Ok(RemoveEmailPayload::IncorrectPassword);
}
repo.user_email().remove(user_email.clone()).await?;
repo.queue_job()
.schedule_job(&mut rng, &clock, ProvisionUserJob::new(&user))
.await?;
repo.save().await?;
Ok(RemoveEmailPayload::Removed(user_email))
}
#[graphql(
deprecation = "This doesn't do anything anymore, but is kept to avoid breaking existing queries"
)]
async fn set_primary_email(
&self,
ctx: &Context<'_>,
input: SetPrimaryEmailInput,
) -> Result<SetPrimaryEmailPayload, async_graphql::Error> {
let state = ctx.state();
let user_email_id = NodeType::UserEmail.extract_ulid(&input.user_email_id)?;
let requester = ctx.requester();
let mut repo = state.repository().await?;
let user_email = repo.user_email().lookup(user_email_id).await?;
let Some(user_email) = user_email else {
return Ok(SetPrimaryEmailPayload::NotFound);
};
if !requester.is_owner_or_admin(&user_email) {
return Err(async_graphql::Error::new("Unauthorized"));
}
if !requester.is_admin() && !state.site_config().email_change_allowed {
return Err(async_graphql::Error::new("Unauthorized"));
}
let user = repo
.user()
.lookup(user_email.user_id)
.await?
.context("Failed to load user")?;
repo.save().await?;
Ok(SetPrimaryEmailPayload::Set(user))
}
async fn start_email_authentication(
&self,
ctx: &Context<'_>,
input: StartEmailAuthenticationInput,
) -> Result<StartEmailAuthenticationPayload, async_graphql::Error> {
let state = ctx.state();
let mut rng = state.rng();
let clock = state.clock();
let requester = ctx.requester();
let limiter = state.limiter();
let Some(browser_session) = requester.browser_session() else {
return Err(async_graphql::Error::new("Unauthorized"));
};
if !state.site_config().email_change_allowed {
return Err(async_graphql::Error::new(
"Email changes are not allowed on this server",
));
}
if !state.site_config().email_change_allowed {
return Err(async_graphql::Error::new(
"Email authentication is not allowed on this server",
));
}
let _: DataLocale = input.language.parse()?;
if input.email.parse::<lettre::Address>().is_err() {
return Ok(StartEmailAuthenticationPayload::InvalidEmailAddress);
}
if let Err(e) =
limiter.check_email_authentication_email(requester.fingerprint(), &input.email)
{
tracing::warn!(error = &e as &dyn std::error::Error);
return Ok(StartEmailAuthenticationPayload::RateLimited);
}
let mut repo = state.repository().await?;
let count = repo
.user_email()
.count(
UserEmailFilter::new()
.for_email(&input.email)
.for_user(&browser_session.user),
)
.await?;
if count > 0 {
return Ok(StartEmailAuthenticationPayload::InUse);
}
let mut policy = state.policy().await?;
let res = policy
.evaluate_email(mas_policy::EmailInput {
email: &input.email,
requester: requester.for_policy(),
})
.await?;
if !res.valid() {
return Ok(StartEmailAuthenticationPayload::Denied {
violations: res.violations,
});
}
if !verify_password_if_needed(
requester,
state.site_config(),
&state.password_manager(),
input.password,
&browser_session.user,
&mut repo,
)
.await?
{
return Ok(StartEmailAuthenticationPayload::IncorrectPassword);
}
let authentication = repo
.user_email()
.add_authentication_for_session(&mut rng, &clock, input.email, browser_session)
.await?;
repo.queue_job()
.schedule_job(
&mut rng,
&clock,
SendEmailAuthenticationCodeJob::new(&authentication, input.language),
)
.await?;
repo.save().await?;
Ok(StartEmailAuthenticationPayload::Started(
UserEmailAuthentication(authentication),
))
}
async fn resend_email_authentication_code(
&self,
ctx: &Context<'_>,
input: ResendEmailAuthenticationCodeInput,
) -> Result<ResendEmailAuthenticationCodePayload, async_graphql::Error> {
let state = ctx.state();
let mut rng = state.rng();
let clock = state.clock();
let limiter = state.limiter();
let requester = ctx.requester();
let id = NodeType::UserEmailAuthentication.extract_ulid(&input.id)?;
let Some(browser_session) = requester.browser_session() else {
return Err(async_graphql::Error::new("Unauthorized"));
};
if !state.site_config().email_change_allowed {
return Err(async_graphql::Error::new(
"Email changes are not allowed on this server",
));
}
let _: DataLocale = input.language.parse()?;
let mut repo = state.repository().await?;
let Some(authentication) = repo.user_email().lookup_authentication(id).await? else {
return Ok(ResendEmailAuthenticationCodePayload::Completed);
};
if authentication.user_session_id != Some(browser_session.id) {
return Err(async_graphql::Error::new("Unauthorized"));
}
if authentication.completed_at.is_some() {
return Ok(ResendEmailAuthenticationCodePayload::Completed);
}
if let Err(e) =
limiter.check_email_authentication_send_code(requester.fingerprint(), &authentication)
{
tracing::warn!(error = &e as &dyn std::error::Error);
return Ok(ResendEmailAuthenticationCodePayload::RateLimited);
}
repo.queue_job()
.schedule_job(
&mut rng,
&clock,
SendEmailAuthenticationCodeJob::new(&authentication, input.language),
)
.await?;
repo.save().await?;
Ok(ResendEmailAuthenticationCodePayload::Resent)
}
async fn complete_email_authentication(
&self,
ctx: &Context<'_>,
input: CompleteEmailAuthenticationInput,
) -> Result<CompleteEmailAuthenticationPayload, async_graphql::Error> {
let state = ctx.state();
let mut rng = state.rng();
let clock = state.clock();
let limiter = state.limiter();
let id = NodeType::UserEmailAuthentication.extract_ulid(&input.id)?;
let Some(browser_session) = ctx.requester().browser_session() else {
return Err(async_graphql::Error::new("Unauthorized"));
};
if !state.site_config().email_change_allowed {
return Err(async_graphql::Error::new(
"Email changes are not allowed on this server",
));
}
let mut repo = state.repository().await?;
let Some(authentication) = repo.user_email().lookup_authentication(id).await? else {
return Ok(CompleteEmailAuthenticationPayload::InvalidCode);
};
if authentication.user_session_id != Some(browser_session.id) {
return Ok(CompleteEmailAuthenticationPayload::InvalidCode);
}
if let Err(e) = limiter.check_email_authentication_attempt(&authentication) {
tracing::warn!(error = &e as &dyn std::error::Error);
return Ok(CompleteEmailAuthenticationPayload::RateLimited);
}
let Some(code) = repo
.user_email()
.find_authentication_code(&authentication, &input.code)
.await?
else {
return Ok(CompleteEmailAuthenticationPayload::InvalidCode);
};
if code.expires_at < state.clock().now() {
return Ok(CompleteEmailAuthenticationPayload::CodeExpired);
}
let authentication = repo
.user_email()
.complete_authentication(&clock, authentication, &code)
.await?;
let count = repo
.user_email()
.count(UserEmailFilter::new().for_email(&authentication.email))
.await?;
if count > 0 {
repo.save().await?;
return Ok(CompleteEmailAuthenticationPayload::InUse);
}
repo.user_email()
.add(
&mut rng,
&clock,
&browser_session.user,
authentication.email,
)
.await?;
repo.save().await?;
Ok(CompleteEmailAuthenticationPayload::Completed)
}
}