mod branding;
mod captcha;
mod ext;
mod features;
use std::{
fmt::Formatter,
net::{IpAddr, Ipv4Addr},
};
use chrono::{DateTime, Duration, Utc};
use http::{Method, Uri, Version};
use mas_data_model::{
AuthorizationGrant, BrowserSession, Client, CompatSsoLogin, CompatSsoLoginState,
DeviceCodeGrant, UpstreamOAuthLink, UpstreamOAuthProvider, User, UserAgent, UserEmail,
UserEmailVerification, UserRecoverySession,
};
use mas_i18n::DataLocale;
use mas_router::{Account, GraphQL, PostAuthAction, UrlBuilder};
use oauth2_types::scope::OPENID;
use rand::{
distributions::{Alphanumeric, DistString},
Rng,
};
use serde::{ser::SerializeStruct, Deserialize, Serialize};
use ulid::Ulid;
use url::Url;
pub use self::{
branding::SiteBranding, captcha::WithCaptcha, ext::SiteConfigExt, features::SiteFeatures,
};
use crate::{FieldError, FormField, FormState};
/// A serializable template context, with combinators that wrap it together
/// with extra data (session, CSRF token, language, captcha configuration).
pub trait TemplateContext: Serialize {
    /// Wrap this context together with the given browser session.
    fn with_session(self, current_session: BrowserSession) -> WithSession<Self>
    where
        Self: Sized,
    {
        let inner = self;
        WithSession {
            inner,
            current_session,
        }
    }

    /// Wrap this context together with an optional browser session.
    fn maybe_with_session(
        self,
        current_session: Option<BrowserSession>,
    ) -> WithOptionalSession<Self>
    where
        Self: Sized,
    {
        let inner = self;
        WithOptionalSession {
            inner,
            current_session,
        }
    }

    /// Wrap this context together with a CSRF token.
    ///
    /// The token is stored in its `ToString` representation.
    fn with_csrf<C>(self, csrf_token: C) -> WithCsrf<Self>
    where
        Self: Sized,
        C: ToString,
    {
        let csrf_token = csrf_token.to_string();
        WithCsrf {
            csrf_token,
            inner: self,
        }
    }

    /// Wrap this context together with a locale, stored as a string.
    fn with_language(self, lang: DataLocale) -> WithLanguage<Self>
    where
        Self: Sized,
    {
        let lang = lang.to_string();
        WithLanguage { lang, inner: self }
    }

    /// Wrap this context together with an optional captcha configuration.
    fn with_captcha(self, captcha: Option<mas_data_model::CaptchaConfig>) -> WithCaptcha<Self>
    where
        Self: Sized,
    {
        let inner = self;
        WithCaptcha::new(captcha, inner)
    }

    /// Generate sample values of this context.
    ///
    /// `now` and `rng` are made available to implementations for building
    /// timestamps and pseudo-random values.
    fn sample(now: chrono::DateTime<Utc>, rng: &mut impl Rng) -> Vec<Self>
    where
        Self: Sized;
}
impl TemplateContext for () {
    /// The unit context yields no samples.
    fn sample(_now: chrono::DateTime<Utc>, _rng: &mut impl Rng) -> Vec<Self>
    where
        Self: Sized,
    {
        vec![]
    }
}
/// Wrapper pairing an inner context with a language string.
///
/// The inner context is flattened, so its fields are serialized next to
/// `lang`.
#[derive(Serialize, Debug)]
pub struct WithLanguage<T> {
/// The language, as a string.
lang: String,
/// The wrapped context, flattened on serialization.
#[serde(flatten)]
inner: T,
}
impl<T> WithLanguage<T> {
    /// Borrow the language string stored in this wrapper.
    pub fn language(&self) -> &str {
        self.lang.as_str()
    }
}
impl<T> std::ops::Deref for WithLanguage<T> {
    type Target = T;

    /// Dereference to the wrapped inner context.
    fn deref(&self) -> &Self::Target {
        let Self { inner, .. } = self;
        inner
    }
}
impl<T: TemplateContext> TemplateContext for WithLanguage<T> {
    /// Wrap every sample of the inner context with the `"en"` language.
    fn sample(now: chrono::DateTime<Utc>, rng: &mut impl Rng) -> Vec<Self>
    where
        Self: Sized,
    {
        let inner_samples = T::sample(now, rng);
        let mut wrapped = Vec::with_capacity(inner_samples.len());
        for inner in inner_samples {
            wrapped.push(WithLanguage {
                lang: "en".into(),
                inner,
            });
        }
        wrapped
    }
}
/// Wrapper pairing an inner context with a CSRF token.
///
/// The inner context is flattened, so its fields are serialized next to
/// `csrf_token`.
#[derive(Serialize, Debug)]
pub struct WithCsrf<T> {
/// The CSRF token, as a string.
csrf_token: String,
/// The wrapped context, flattened on serialization.
#[serde(flatten)]
inner: T,
}
impl<T: TemplateContext> TemplateContext for WithCsrf<T> {
    /// Wrap every sample of the inner context with a fixed fake CSRF token.
    fn sample(now: chrono::DateTime<Utc>, rng: &mut impl Rng) -> Vec<Self>
    where
        Self: Sized,
    {
        let inner_samples = T::sample(now, rng);
        let mut wrapped = Vec::with_capacity(inner_samples.len());
        for inner in inner_samples {
            wrapped.push(WithCsrf {
                csrf_token: "fake_csrf_token".into(),
                inner,
            });
        }
        wrapped
    }
}
/// Wrapper pairing an inner context with a browser session.
///
/// The inner context is flattened, so its fields are serialized next to
/// `current_session`.
#[derive(Serialize)]
pub struct WithSession<T> {
/// The browser session attached to the context.
current_session: BrowserSession,
/// The wrapped context, flattened on serialization.
#[serde(flatten)]
inner: T,
}
impl<T: TemplateContext> TemplateContext for WithSession<T> {
    /// Combine every sample browser session with freshly generated samples of
    /// the inner context.
    fn sample(now: chrono::DateTime<Utc>, rng: &mut impl Rng) -> Vec<Self>
    where
        Self: Sized,
    {
        let mut combined = Vec::new();
        for session in BrowserSession::samples(now, rng) {
            // The inner samples are regenerated for each session
            for inner in T::sample(now, rng) {
                combined.push(WithSession {
                    current_session: session.clone(),
                    inner,
                });
            }
        }
        combined
    }
}
/// Wrapper pairing an inner context with an optional browser session.
///
/// The inner context is flattened, so its fields are serialized next to
/// `current_session`.
#[derive(Serialize)]
pub struct WithOptionalSession<T> {
/// The browser session attached to the context, if any.
current_session: Option<BrowserSession>,
/// The wrapped context, flattened on serialization.
#[serde(flatten)]
inner: T,
}
impl<T: TemplateContext> TemplateContext for WithOptionalSession<T> {
    /// Combine every sample browser session, followed by the "no session"
    /// case, with freshly generated samples of the inner context.
    fn sample(now: chrono::DateTime<Utc>, rng: &mut impl Rng) -> Vec<Self>
    where
        Self: Sized,
    {
        // All sample sessions wrapped in `Some`, then a trailing `None`
        let sessions = BrowserSession::samples(now, rng)
            .into_iter()
            .map(Some)
            .chain(std::iter::once(None));

        let mut combined = Vec::new();
        for session in sessions {
            // The inner samples are regenerated for each session
            for inner in T::sample(now, rng) {
                combined.push(WithOptionalSession {
                    current_session: session.clone(),
                    inner,
                });
            }
        }
        combined
    }
}
/// A context carrying no data of its own.
pub struct EmptyContext;
impl Serialize for EmptyContext {
    /// Serialize as a struct holding a single placeholder field, `__UNUSED`,
    /// whose value is the unit type.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        // The length given to `serialize_struct` is the number of fields that
        // get serialized afterwards: exactly one here.
        let mut s = serializer.serialize_struct("EmptyContext", 1)?;
        // NOTE(review): the placeholder field presumably exists so that this
        // context still works when flattened into the wrapper types — confirm.
        s.serialize_field("__UNUSED", &())?;
        s.end()
    }
}
impl TemplateContext for EmptyContext {
    /// The only sample is the empty context itself.
    fn sample(_now: chrono::DateTime<Utc>, _rng: &mut impl Rng) -> Vec<Self>
    where
        Self: Sized,
    {
        vec![Self]
    }
}
/// Context holding a discovery URL.
#[derive(Serialize)]
pub struct IndexContext {
/// The discovery document URL.
discovery_url: Url,
}
impl IndexContext {
#[must_use]
pub fn new(discovery_url: Url) -> Self {
Self { discovery_url }
}
}
impl TemplateContext for IndexContext {
    /// A single sample pointing at an example discovery document.
    fn sample(_now: chrono::DateTime<Utc>, _rng: &mut impl Rng) -> Vec<Self>
    where
        Self: Sized,
    {
        let discovery_url = "https://example.com/.well-known/openid-configuration"
            .parse()
            .unwrap();
        vec![Self { discovery_url }]
    }
}
/// Configuration values exposed through [`AppContext`].
///
/// Field names are serialized in camelCase.
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
pub struct AppConfig {
/// Root path, as a string.
root: String,
/// GraphQL endpoint, as a string (serialized as `graphqlEndpoint`).
graphql_endpoint: String,
}
/// Context wrapping an [`AppConfig`].
#[derive(Serialize)]
pub struct AppContext {
/// The application configuration.
app_config: AppConfig,
}
impl AppContext {
    /// Build the context from the relative URLs the given [`UrlBuilder`]
    /// produces for the default account route and the GraphQL route.
    #[must_use]
    pub fn from_url_builder(url_builder: &UrlBuilder) -> Self {
        let app_config = AppConfig {
            root: url_builder.relative_url_for(&Account::default()),
            graphql_endpoint: url_builder.relative_url_for(&GraphQL),
        };
        Self { app_config }
    }
}
impl TemplateContext for AppContext {
    /// A single sample built from a URL builder rooted at
    /// `https://example.com/`.
    fn sample(_now: chrono::DateTime<Utc>, _rng: &mut impl Rng) -> Vec<Self>
    where
        Self: Sized,
    {
        let base = "https://example.com/".parse().unwrap();
        let url_builder = UrlBuilder::new(base, None, None);
        vec![Self::from_url_builder(&url_builder)]
    }
}
/// Context holding the URLs needed by the API documentation.
#[derive(Serialize)]
pub struct ApiDocContext {
/// URL of the OpenAPI specification.
openapi_url: Url,
/// URL of the API documentation callback.
callback_url: Url,
}
impl ApiDocContext {
    /// Build the context from the absolute URLs the given [`UrlBuilder`]
    /// produces for the API spec and API doc callback routes.
    #[must_use]
    pub fn from_url_builder(url_builder: &UrlBuilder) -> Self {
        let openapi_url = url_builder.absolute_url_for(&mas_router::ApiSpec);
        let callback_url = url_builder.absolute_url_for(&mas_router::ApiDocCallback);
        Self {
            openapi_url,
            callback_url,
        }
    }
}
impl TemplateContext for ApiDocContext {
    /// A single sample built from a URL builder rooted at
    /// `https://example.com/`.
    fn sample(_now: chrono::DateTime<Utc>, _rng: &mut impl Rng) -> Vec<Self>
    where
        Self: Sized,
    {
        let base = "https://example.com/".parse().unwrap();
        let url_builder = UrlBuilder::new(base, None, None);
        vec![Self::from_url_builder(&url_builder)]
    }
}
/// Fields of the login form, serialized in snake_case.
#[derive(Serialize, Deserialize, Debug, Clone, Copy, Hash, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum LoginFormField {
    /// The username field
    Username,

    /// The password field
    Password,
}

impl FormField for LoginFormField {
    /// Returns `true` for the username and `false` for the password.
    ///
    /// NOTE(review): presumably controls whether a submitted value is retained
    /// by the form state — confirm against `FormField`.
    fn keep(&self) -> bool {
        match *self {
            Self::Password => false,
            Self::Username => true,
        }
    }
}
/// Data attached to a post-authentication action.
///
/// Serialized with a `kind` tag holding the snake_case variant name.
#[derive(Serialize)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum PostAuthContextInner {
/// Continue an authorization grant
ContinueAuthorizationGrant {
/// The authorization grant to continue
grant: Box<AuthorizationGrant>,
},
/// Continue a device code grant
ContinueDeviceCodeGrant {
/// The device code grant to continue
grant: Box<DeviceCodeGrant>,
},
/// Continue a compat SSO login
ContinueCompatSsoLogin {
/// The compat SSO login to continue
login: Box<CompatSsoLogin>,
},
/// Change the password
ChangePassword,
/// Link an upstream account
LinkUpstream {
/// The upstream provider
provider: Box<UpstreamOAuthProvider>,
/// The upstream link
link: Box<UpstreamOAuthLink>,
},
/// Manage the account
ManageAccount,
}
/// A post-authentication action together with its associated data.
#[derive(Serialize)]
pub struct PostAuthContext {
/// The post-auth action parameters.
pub params: PostAuthAction,
/// The data for that action, flattened on serialization.
#[serde(flatten)]
pub ctx: PostAuthContextInner,
}
/// Context holding the login form state, an optional post-auth action and
/// the list of upstream OAuth providers.
#[derive(Serialize, Default)]
pub struct LoginContext {
/// State of the login form.
form: FormState<LoginFormField>,
/// Post-auth action context, if any.
next: Option<PostAuthContext>,
/// Upstream OAuth providers.
providers: Vec<UpstreamOAuthProvider>,
}
impl TemplateContext for LoginContext {
    /// Four samples: two blank forms, one form with a "required" error on the
    /// username plus a policy error on the password, and one form with an
    /// "exists" error on the username.
    fn sample(_now: chrono::DateTime<Utc>, _rng: &mut impl Rng) -> Vec<Self>
    where
        Self: Sized,
    {
        let required_and_policy = FormState::default()
            .with_error_on_field(LoginFormField::Username, FieldError::Required)
            .with_error_on_field(
                LoginFormField::Password,
                FieldError::Policy {
                    message: "password too short".to_owned(),
                },
            );
        let username_exists =
            FormState::default().with_error_on_field(LoginFormField::Username, FieldError::Exists);

        vec![
            // The default value is a blank form with no next action and no
            // providers
            Self::default(),
            Self::default(),
            Self {
                form: required_and_policy,
                ..Self::default()
            },
            Self {
                form: username_exists,
                ..Self::default()
            },
        ]
    }
}
impl LoginContext {
    /// Replace the form state.
    #[must_use]
    pub fn with_form_state(mut self, form: FormState<LoginFormField>) -> Self {
        self.form = form;
        self
    }

    /// Replace the list of upstream providers.
    #[must_use]
    pub fn with_upstream_providers(mut self, providers: Vec<UpstreamOAuthProvider>) -> Self {
        self.providers = providers;
        self
    }

    /// Attach a post-auth action context.
    #[must_use]
    pub fn with_post_action(mut self, context: PostAuthContext) -> Self {
        self.next = Some(context);
        self
    }
}
/// Fields of the registration form, serialized in snake_case.
#[derive(Serialize, Deserialize, Debug, Clone, Copy, Hash, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum RegisterFormField {
    /// The username field
    Username,

    /// The email field
    Email,

    /// The password field
    Password,

    /// The password confirmation field
    PasswordConfirm,

    /// The terms of service agreement field
    AcceptTerms,
}

impl FormField for RegisterFormField {
    /// Returns `false` for the two password fields and `true` for every other
    /// field.
    ///
    /// NOTE(review): presumably controls whether a submitted value is retained
    /// by the form state — confirm against `FormField`.
    fn keep(&self) -> bool {
        match *self {
            Self::Password | Self::PasswordConfirm => false,
            Self::Username | Self::Email | Self::AcceptTerms => true,
        }
    }
}
/// Context holding the registration form state and an optional post-auth
/// action.
#[derive(Serialize, Default)]
pub struct RegisterContext {
/// State of the registration form.
form: FormState<RegisterFormField>,
/// Post-auth action context, if any.
next: Option<PostAuthContext>,
}
impl TemplateContext for RegisterContext {
    /// A single sample: a blank form with no next action.
    fn sample(_now: chrono::DateTime<Utc>, _rng: &mut impl Rng) -> Vec<Self>
    where
        Self: Sized,
    {
        let form = FormState::default();
        vec![RegisterContext { form, next: None }]
    }
}
impl RegisterContext {
    /// Replace the form state.
    #[must_use]
    pub fn with_form_state(mut self, form: FormState<RegisterFormField>) -> Self {
        self.form = form;
        self
    }

    /// Attach a post-auth action context.
    #[must_use]
    pub fn with_post_action(mut self, next: PostAuthContext) -> Self {
        self.next = Some(next);
        self
    }
}
/// Context holding an authorization grant, its client and the post-auth
/// action that continues the grant.
#[derive(Serialize)]
pub struct ConsentContext {
/// The authorization grant.
grant: AuthorizationGrant,
/// The client.
client: Client,
/// The post-auth action derived from the grant ID.
action: PostAuthAction,
}
impl TemplateContext for ConsentContext {
    /// One sample per sample client, each with a freshly sampled
    /// authorization grant re-pointed at that client.
    fn sample(now: chrono::DateTime<Utc>, rng: &mut impl Rng) -> Vec<Self>
    where
        Self: Sized,
    {
        let clients = Client::samples(now, rng);
        let mut contexts = Vec::with_capacity(clients.len());
        for client in clients {
            let mut grant = AuthorizationGrant::sample(now, rng);
            let action = PostAuthAction::continue_grant(grant.id);
            // Make the sampled grant belong to the sampled client
            grant.client_id = client.id;
            contexts.push(Self {
                grant,
                client,
                action,
            });
        }
        contexts
    }
}
impl ConsentContext {
    /// Build the context for the given grant and client.
    ///
    /// The post-auth action is the "continue grant" action for the grant's ID.
    #[must_use]
    pub fn new(grant: AuthorizationGrant, client: Client) -> Self {
        Self {
            // Evaluated first, while `grant` is still available
            action: PostAuthAction::continue_grant(grant.id),
            grant,
            client,
        }
    }
}
/// The grant a policy violation refers to.
///
/// Serialized with a `grant_type` tag whose value is given by the per-variant
/// `rename` attributes.
#[derive(Serialize)]
#[serde(tag = "grant_type")]
enum PolicyViolationGrant {
/// An authorization grant, tagged `authorization_code`.
#[serde(rename = "authorization_code")]
Authorization(AuthorizationGrant),
/// A device code grant, tagged
/// `urn:ietf:params:oauth:grant-type:device_code`.
#[serde(rename = "urn:ietf:params:oauth:grant-type:device_code")]
DeviceCode(DeviceCodeGrant),
}
/// Context holding a grant (authorization or device code), its client and
/// the post-auth action that continues the grant.
#[derive(Serialize)]
pub struct PolicyViolationContext {
/// The grant concerned.
grant: PolicyViolationGrant,
/// The client.
client: Client,
/// The post-auth action derived from the grant ID.
action: PostAuthAction,
}
impl TemplateContext for PolicyViolationContext {
    /// Two samples per sample client: one for an authorization grant and one
    /// for a pending device code grant, in that order.
    fn sample(now: chrono::DateTime<Utc>, rng: &mut impl Rng) -> Vec<Self>
    where
        Self: Sized,
    {
        let mut contexts = Vec::new();
        for client in Client::samples(now, rng) {
            // Authorization grant flavour
            let mut authorization_grant = AuthorizationGrant::sample(now, rng);
            authorization_grant.client_id = client.id;
            contexts.push(PolicyViolationContext::for_authorization_grant(
                authorization_grant,
                client.clone(),
            ));

            // Device code grant flavour; fields are listed in the order in
            // which they draw from `rng` (id, user code, device code)
            let device_code_grant = DeviceCodeGrant {
                id: Ulid::from_datetime_with_source(now.into(), rng),
                state: mas_data_model::DeviceCodeGrantState::Pending,
                client_id: client.id,
                scope: [OPENID].into_iter().collect(),
                user_code: Alphanumeric.sample_string(rng, 6).to_uppercase(),
                device_code: Alphanumeric.sample_string(rng, 32),
                created_at: now - Duration::try_minutes(5).unwrap(),
                expires_at: now + Duration::try_minutes(25).unwrap(),
                ip_address: None,
                user_agent: None,
            };
            contexts.push(PolicyViolationContext::for_device_code_grant(
                device_code_grant,
                client,
            ));
        }
        contexts
    }
}
impl PolicyViolationContext {
    /// Build the context for an authorization grant.
    ///
    /// The post-auth action is the "continue grant" action for the grant's ID.
    #[must_use]
    pub const fn for_authorization_grant(grant: AuthorizationGrant, client: Client) -> Self {
        let grant_id = grant.id;
        let action = PostAuthAction::continue_grant(grant_id);
        let grant = PolicyViolationGrant::Authorization(grant);
        Self {
            grant,
            client,
            action,
        }
    }

    /// Build the context for a device code grant.
    ///
    /// The post-auth action is the "continue device code grant" action for
    /// the grant's ID.
    #[must_use]
    pub const fn for_device_code_grant(grant: DeviceCodeGrant, client: Client) -> Self {
        let grant_id = grant.id;
        let action = PostAuthAction::continue_device_code_grant(grant_id);
        let grant = PolicyViolationGrant::DeviceCode(grant);
        Self {
            grant,
            client,
            action,
        }
    }
}
/// Fields of the reauthentication form, serialized in kebab-case.
#[derive(Serialize, Deserialize, Debug, Clone, Copy, Hash, PartialEq, Eq)]
#[serde(rename_all = "kebab-case")]
pub enum ReauthFormField {
    /// The password field
    Password,
}

impl FormField for ReauthFormField {
    /// Returns `false` for the password, the only field.
    ///
    /// NOTE(review): presumably controls whether a submitted value is retained
    /// by the form state — confirm against `FormField`.
    fn keep(&self) -> bool {
        // Irrefutable as long as `Password` is the only variant
        let Self::Password = self;
        false
    }
}
/// Context holding the reauthentication form state and an optional post-auth
/// action.
#[derive(Serialize, Default)]
pub struct ReauthContext {
/// State of the reauthentication form.
form: FormState<ReauthFormField>,
/// Post-auth action context, if any.
next: Option<PostAuthContext>,
}
impl TemplateContext for ReauthContext {
    /// A single sample: a blank form with no next action.
    fn sample(_now: chrono::DateTime<Utc>, _rng: &mut impl Rng) -> Vec<Self>
    where
        Self: Sized,
    {
        let form = FormState::default();
        vec![ReauthContext { form, next: None }]
    }
}
impl ReauthContext {
    /// Replace the form state.
    #[must_use]
    pub fn with_form_state(mut self, form: FormState<ReauthFormField>) -> Self {
        self.form = form;
        self
    }

    /// Attach a post-auth action context.
    #[must_use]
    pub fn with_post_action(mut self, next: PostAuthContext) -> Self {
        self.next = Some(next);
        self
    }
}
/// Context holding a compat SSO login and the post-auth action that
/// continues it.
#[derive(Serialize)]
pub struct CompatSsoContext {
/// The compat SSO login.
login: CompatSsoLogin,
/// The post-auth action derived from the login ID.
action: PostAuthAction,
}
impl TemplateContext for CompatSsoContext {
    /// A single sample wrapping a pending compat SSO login.
    fn sample(now: chrono::DateTime<Utc>, rng: &mut impl Rng) -> Vec<Self>
    where
        Self: Sized,
    {
        let login = CompatSsoLogin {
            id: Ulid::from_datetime_with_source(now.into(), rng),
            redirect_uri: Url::parse("https://app.element.io/").unwrap(),
            login_token: "abcdefghijklmnopqrstuvwxyz012345".into(),
            created_at: now,
            state: CompatSsoLoginState::Pending,
        };
        vec![CompatSsoContext::new(login)]
    }
}
impl CompatSsoContext {
    /// Build the context for the given compat SSO login.
    ///
    /// The post-auth action is the "continue compat SSO login" action for the
    /// login's ID.
    #[must_use]
    pub fn new(login: CompatSsoLogin) -> Self {
        Self {
            // Evaluated first, while `login` is still available
            action: PostAuthAction::continue_compat_sso_login(login.id),
            login,
        }
    }
}
/// Context holding a user, a recovery session and a recovery link.
#[derive(Serialize)]
pub struct EmailRecoveryContext {
/// The user.
user: User,
/// The recovery session.
session: UserRecoverySession,
/// The recovery link.
recovery_link: Url,
}
impl EmailRecoveryContext {
#[must_use]
pub fn new(user: User, session: UserRecoverySession, recovery_link: Url) -> Self {
Self {
user,
session,
recovery_link,
}
}
#[must_use]
pub fn user(&self) -> &User {
&self.user
}
#[must_use]
pub fn session(&self) -> &UserRecoverySession {
&self.session
}
}
impl TemplateContext for EmailRecoveryContext {
    /// One sample per sample user, each with an unconsumed recovery session
    /// and a fixed example recovery link.
    fn sample(now: chrono::DateTime<Utc>, rng: &mut impl Rng) -> Vec<Self>
    where
        Self: Sized,
    {
        let users = User::samples(now, rng);
        let mut contexts = Vec::with_capacity(users.len());
        for user in users {
            let session = UserRecoverySession {
                id: Ulid::from_datetime_with_source(now.into(), rng),
                email: "hello@example.com".to_owned(),
                user_agent: UserAgent::parse("Mozilla/5.0 (Macintosh; Intel Mac OS X 10_8_4) AppleWebKit/536.30.1 (KHTML, like Gecko) Version/6.0.5 Safari/536.30.1".to_owned()),
                ip_address: Some(IpAddr::from([192_u8, 0, 2, 1])),
                locale: "en".to_owned(),
                created_at: now,
                consumed_at: None,
            };
            let link =
                "https://example.com/recovery/complete?ticket=abcdefghijklmnopqrstuvwxyz0123456789"
                    .parse()
                    .unwrap();
            contexts.push(Self::new(user, session, link));
        }
        contexts
    }
}
/// Context holding a user and an email verification.
#[derive(Serialize)]
pub struct EmailVerificationContext {
/// The user.
user: User,
/// The email verification.
verification: UserEmailVerification,
}
impl EmailVerificationContext {
#[must_use]
pub fn new(user: User, verification: UserEmailVerification) -> Self {
Self { user, verification }
}
#[must_use]
pub fn user(&self) -> &User {
&self.user
}
#[must_use]
pub fn verification(&self) -> &UserEmailVerification {
&self.verification
}
}
impl TemplateContext for EmailVerificationContext {
    /// One sample per sample user, each with a valid verification for an
    /// unconfirmed example email address.
    fn sample(now: chrono::DateTime<Utc>, rng: &mut impl Rng) -> Vec<Self>
    where
        Self: Sized,
    {
        let users = User::samples(now, rng);
        let mut contexts = Vec::with_capacity(users.len());
        for user in users {
            // The email itself is only used to give the verification an ID to
            // point at
            let email = UserEmail {
                id: Ulid::from_datetime_with_source(now.into(), rng),
                user_id: user.id,
                email: "foobar@example.com".to_owned(),
                created_at: now,
                confirmed_at: None,
            };

            let verification = UserEmailVerification {
                id: Ulid::from_datetime_with_source(now.into(), rng),
                user_email_id: email.id,
                code: "123456".to_owned(),
                created_at: now,
                state: mas_data_model::UserEmailVerificationState::Valid,
            };

            contexts.push(Self { user, verification });
        }
        contexts
    }
}
/// Fields of the email verification form, serialized in snake_case.
#[derive(Serialize, Deserialize, Debug, Clone, Copy, Hash, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum EmailVerificationFormField {
    /// The code field
    Code,
}

impl FormField for EmailVerificationFormField {
    /// Returns `true` for the code, the only field.
    ///
    /// NOTE(review): presumably controls whether a submitted value is retained
    /// by the form state — confirm against `FormField`.
    fn keep(&self) -> bool {
        // Irrefutable as long as `Code` is the only variant
        let Self::Code = self;
        true
    }
}
/// Context holding the email verification form state and the email address
/// concerned.
#[derive(Serialize)]
pub struct EmailVerificationPageContext {
/// State of the verification form.
form: FormState<EmailVerificationFormField>,
/// The user email.
email: UserEmail,
}
impl EmailVerificationPageContext {
    /// Build the context for the given email, with a default form state.
    #[must_use]
    pub fn new(email: UserEmail) -> Self {
        let form = FormState::default();
        Self { form, email }
    }

    /// Replace the form state.
    #[must_use]
    pub fn with_form_state(mut self, form: FormState<EmailVerificationFormField>) -> Self {
        self.form = form;
        self
    }
}
impl TemplateContext for EmailVerificationPageContext {
    /// A single sample: a default form for an unconfirmed example email.
    fn sample(now: chrono::DateTime<Utc>, rng: &mut impl Rng) -> Vec<Self>
    where
        Self: Sized,
    {
        // The email ID is drawn before the user ID
        let id = Ulid::from_datetime_with_source(now.into(), rng);
        let user_id = Ulid::from_datetime_with_source(now.into(), rng);
        let email = UserEmail {
            id,
            user_id,
            email: "foobar@example.com".to_owned(),
            created_at: now,
            confirmed_at: None,
        };

        let form = FormState::default();
        vec![Self { form, email }]
    }
}
/// Fields of the "add email" form, serialized in snake_case.
#[derive(Serialize, Deserialize, Debug, Clone, Copy, Hash, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum EmailAddFormField {
    /// The email field
    Email,
}

impl FormField for EmailAddFormField {
    /// Returns `true` for the email, the only field.
    ///
    /// NOTE(review): presumably controls whether a submitted value is retained
    /// by the form state — confirm against `FormField`.
    fn keep(&self) -> bool {
        // Irrefutable as long as `Email` is the only variant
        let Self::Email = self;
        true
    }
}
/// Context holding the "add email" form state.
#[derive(Serialize, Default)]
pub struct EmailAddContext {
/// State of the form.
form: FormState<EmailAddFormField>,
}
impl EmailAddContext {
#[must_use]
pub fn new() -> Self {
Self::default()
}
#[must_use]
pub fn with_form_state(form: FormState<EmailAddFormField>) -> Self {
Self { form }
}
}
impl TemplateContext for EmailAddContext {
    /// A single sample: the default context.
    fn sample(_now: chrono::DateTime<Utc>, _rng: &mut impl Rng) -> Vec<Self>
    where
        Self: Sized,
    {
        let default_context = Self::default();
        vec![default_context]
    }
}
/// Fields of the "start recovery" form, serialized in snake_case.
#[derive(Serialize, Deserialize, Debug, Clone, Copy, Hash, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum RecoveryStartFormField {
    /// The email field
    Email,
}

impl FormField for RecoveryStartFormField {
    /// Returns `true` for the email, the only field.
    ///
    /// NOTE(review): presumably controls whether a submitted value is retained
    /// by the form state — confirm against `FormField`.
    fn keep(&self) -> bool {
        // Irrefutable as long as `Email` is the only variant
        let Self::Email = self;
        true
    }
}
/// Context holding the "start recovery" form state.
#[derive(Serialize, Default)]
pub struct RecoveryStartContext {
/// State of the form.
form: FormState<RecoveryStartFormField>,
}
impl RecoveryStartContext {
    /// Build the context with its default form state.
    #[must_use]
    pub fn new() -> Self {
        Default::default()
    }

    /// Replace the form state.
    #[must_use]
    pub fn with_form_state(mut self, form: FormState<RecoveryStartFormField>) -> Self {
        self.form = form;
        self
    }
}
impl TemplateContext for RecoveryStartContext {
    /// Three samples: a blank form, then the form with a "required" error and
    /// with an "invalid" error on the email field.
    fn sample(_now: chrono::DateTime<Utc>, _rng: &mut impl Rng) -> Vec<Self>
    where
        Self: Sized,
    {
        let required = FormState::default()
            .with_error_on_field(RecoveryStartFormField::Email, FieldError::Required);
        let invalid = FormState::default()
            .with_error_on_field(RecoveryStartFormField::Email, FieldError::Invalid);

        vec![
            Self::new(),
            Self::new().with_form_state(required),
            Self::new().with_form_state(invalid),
        ]
    }
}
/// Context holding a recovery session and a flag telling whether a resend
/// failed because of rate limiting.
#[derive(Serialize)]
pub struct RecoveryProgressContext {
/// The recovery session.
session: UserRecoverySession,
/// Whether resending failed due to a rate limit.
resend_failed_due_to_rate_limit: bool,
}
impl RecoveryProgressContext {
#[must_use]
pub fn new(session: UserRecoverySession, resend_failed_due_to_rate_limit: bool) -> Self {
Self {
session,
resend_failed_due_to_rate_limit,
}
}
}
impl TemplateContext for RecoveryProgressContext {
    /// Two samples sharing one unconsumed session: first with the rate-limit
    /// flag unset, then with it set.
    fn sample(now: chrono::DateTime<Utc>, rng: &mut impl Rng) -> Vec<Self>
    where
        Self: Sized,
    {
        let session = UserRecoverySession {
            id: Ulid::from_datetime_with_source(now.into(), rng),
            email: "name@mail.com".to_owned(),
            user_agent: UserAgent::parse("Mozilla/5.0".to_owned()),
            ip_address: None,
            locale: "en".to_owned(),
            created_at: now,
            consumed_at: None,
        };

        let not_rate_limited = Self {
            session: session.clone(),
            resend_failed_due_to_rate_limit: false,
        };
        let rate_limited = Self {
            session,
            resend_failed_due_to_rate_limit: true,
        };

        vec![not_rate_limited, rate_limited]
    }
}
/// Context holding a recovery session.
#[derive(Serialize)]
pub struct RecoveryExpiredContext {
/// The recovery session.
session: UserRecoverySession,
}
impl RecoveryExpiredContext {
#[must_use]
pub fn new(session: UserRecoverySession) -> Self {
Self { session }
}
}
impl TemplateContext for RecoveryExpiredContext {
    /// A single sample wrapping an unconsumed example session.
    fn sample(now: chrono::DateTime<Utc>, rng: &mut impl Rng) -> Vec<Self>
    where
        Self: Sized,
    {
        let id = Ulid::from_datetime_with_source(now.into(), rng);
        let session = UserRecoverySession {
            id,
            email: "name@mail.com".to_owned(),
            user_agent: UserAgent::parse("Mozilla/5.0".to_owned()),
            ip_address: None,
            locale: "en".to_owned(),
            created_at: now,
            consumed_at: None,
        };

        vec![Self { session }]
    }
}
/// Fields of the "finish recovery" form, serialized in snake_case.
#[derive(Serialize, Deserialize, Debug, Clone, Copy, Hash, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum RecoveryFinishFormField {
    /// The new password field
    NewPassword,

    /// The new password confirmation field
    NewPasswordConfirm,
}

impl FormField for RecoveryFinishFormField {
    /// Returns `false` for every field.
    ///
    /// NOTE(review): presumably controls whether a submitted value is retained
    /// by the form state — confirm against `FormField`.
    fn keep(&self) -> bool {
        match self {
            Self::NewPassword | Self::NewPasswordConfirm => false,
        }
    }
}
/// Context holding a user and the "finish recovery" form state.
#[derive(Serialize)]
pub struct RecoveryFinishContext {
/// The user.
user: User,
/// State of the form.
form: FormState<RecoveryFinishFormField>,
}
impl RecoveryFinishContext {
    /// Build the context for the given user, with a default form state.
    #[must_use]
    pub fn new(user: User) -> Self {
        let form = FormState::default();
        Self { user, form }
    }

    /// Replace the form state.
    #[must_use]
    pub fn with_form_state(self, form: FormState<RecoveryFinishFormField>) -> Self {
        Self { form, ..self }
    }
}
impl TemplateContext for RecoveryFinishContext {
    /// Three samples per sample user: a blank form, a form with an "invalid"
    /// error on the new password, and a form with an "invalid" error on the
    /// password confirmation.
    fn sample(now: chrono::DateTime<Utc>, rng: &mut impl Rng) -> Vec<Self>
    where
        Self: Sized,
    {
        User::samples(now, rng)
            .into_iter()
            .flat_map(|user| {
                // A fixed-size array avoids one heap allocation per user
                [
                    Self::new(user.clone()),
                    Self::new(user.clone()).with_form_state(
                        FormState::default().with_error_on_field(
                            RecoveryFinishFormField::NewPassword,
                            FieldError::Invalid,
                        ),
                    ),
                    // Last use of `user`: move it instead of cloning
                    Self::new(user).with_form_state(
                        FormState::default().with_error_on_field(
                            RecoveryFinishFormField::NewPasswordConfirm,
                            FieldError::Invalid,
                        ),
                    ),
                ]
            })
            .collect()
    }
}
/// Context holding the user an upstream link is already attached to.
#[derive(Serialize)]
pub struct UpstreamExistingLinkContext {
/// The linked user.
linked_user: User,
}
impl UpstreamExistingLinkContext {
#[must_use]
pub fn new(linked_user: User) -> Self {
Self { linked_user }
}
}
impl TemplateContext for UpstreamExistingLinkContext {
    /// One sample per sample user.
    fn sample(now: chrono::DateTime<Utc>, rng: &mut impl Rng) -> Vec<Self>
    where
        Self: Sized,
    {
        let users = User::samples(now, rng);
        let mut contexts = Vec::with_capacity(users.len());
        for linked_user in users {
            contexts.push(Self { linked_user });
        }
        contexts
    }
}
/// Context holding the post-auth action that links an upstream account.
#[derive(Serialize)]
pub struct UpstreamSuggestLink {
/// The "link upstream" post-auth action.
post_logout_action: PostAuthAction,
}
impl UpstreamSuggestLink {
    /// Build the context for the given upstream link.
    #[must_use]
    pub fn new(link: &UpstreamOAuthLink) -> Self {
        let link_id = link.id;
        Self::for_link_id(link_id)
    }

    /// Build the context from a link ID, using the "link upstream" action for
    /// that ID.
    fn for_link_id(id: Ulid) -> Self {
        Self {
            post_logout_action: PostAuthAction::link_upstream(id),
        }
    }
}
impl TemplateContext for UpstreamSuggestLink {
    /// A single sample built from a freshly generated link ID.
    fn sample(now: chrono::DateTime<Utc>, rng: &mut impl Rng) -> Vec<Self>
    where
        Self: Sized,
    {
        let link_id = Ulid::from_datetime_with_source(now.into(), rng);
        let context = Self::for_link_id(link_id);
        vec![context]
    }
}
/// Fields of the upstream registration form, serialized in snake_case.
#[derive(Serialize, Deserialize, Debug, Clone, Copy, Hash, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum UpstreamRegisterFormField {
    /// The username field
    Username,

    /// The terms of service agreement field
    AcceptTerms,
}

impl FormField for UpstreamRegisterFormField {
    /// Returns `true` for every field.
    ///
    /// NOTE(review): presumably controls whether a submitted value is retained
    /// by the form state — confirm against `FormField`.
    fn keep(&self) -> bool {
        match *self {
            Self::Username => true,
            Self::AcceptTerms => true,
        }
    }
}
/// Context holding optionally imported localpart, display name and email,
/// a "force" flag for each, and the registration form state.
#[derive(Serialize, Default)]
pub struct UpstreamRegister {
/// Imported localpart, if any.
imported_localpart: Option<String>,
/// "Force" flag for the localpart.
force_localpart: bool,
/// Imported display name, if any.
imported_display_name: Option<String>,
/// "Force" flag for the display name.
force_display_name: bool,
/// Imported email, if any.
imported_email: Option<String>,
/// "Force" flag for the email.
force_email: bool,
/// State of the registration form.
form_state: FormState<UpstreamRegisterFormField>,
}
impl UpstreamRegister {
    /// Build the context with every field at its default value.
    #[must_use]
    pub fn new() -> Self {
        Default::default()
    }

    /// Set the imported localpart and its "force" flag in place.
    pub fn set_localpart(&mut self, localpart: String, force: bool) {
        self.force_localpart = force;
        self.imported_localpart = Some(localpart);
    }

    /// Builder form of [`Self::set_localpart`].
    #[must_use]
    pub fn with_localpart(mut self, localpart: String, force: bool) -> Self {
        self.set_localpart(localpart, force);
        self
    }

    /// Set the imported display name and its "force" flag in place.
    pub fn set_display_name(&mut self, display_name: String, force: bool) {
        self.force_display_name = force;
        self.imported_display_name = Some(display_name);
    }

    /// Builder form of [`Self::set_display_name`].
    #[must_use]
    pub fn with_display_name(mut self, display_name: String, force: bool) -> Self {
        self.set_display_name(display_name, force);
        self
    }

    /// Set the imported email and its "force" flag in place.
    pub fn set_email(&mut self, email: String, force: bool) {
        self.force_email = force;
        self.imported_email = Some(email);
    }

    /// Builder form of [`Self::set_email`].
    #[must_use]
    pub fn with_email(mut self, email: String, force: bool) -> Self {
        self.set_email(email, force);
        self
    }

    /// Replace the form state in place.
    pub fn set_form_state(&mut self, form_state: FormState<UpstreamRegisterFormField>) {
        self.form_state = form_state;
    }

    /// Builder form of [`Self::set_form_state`].
    #[must_use]
    pub fn with_form_state(mut self, form_state: FormState<UpstreamRegisterFormField>) -> Self {
        self.set_form_state(form_state);
        self
    }
}
impl TemplateContext for UpstreamRegister {
    /// A single sample: the context as built by `new`.
    fn sample(_now: chrono::DateTime<Utc>, _rng: &mut impl Rng) -> Vec<Self>
    where
        Self: Sized,
    {
        let context = Self::new();
        vec![context]
    }
}
/// Fields of the device link form, serialized in snake_case.
#[derive(Serialize, Deserialize, Debug, Clone, Copy, Hash, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum DeviceLinkFormField {
    /// The code field
    Code,
}

impl FormField for DeviceLinkFormField {
    /// Returns `true` for the code, the only field.
    ///
    /// NOTE(review): presumably controls whether a submitted value is retained
    /// by the form state — confirm against `FormField`.
    fn keep(&self) -> bool {
        // Irrefutable as long as `Code` is the only variant
        let Self::Code = self;
        true
    }
}
/// Context holding the device link form state.
#[derive(Serialize, Default, Debug)]
pub struct DeviceLinkContext {
/// State of the form.
form_state: FormState<DeviceLinkFormField>,
}
impl DeviceLinkContext {
    /// Build the context with its default form state.
    #[must_use]
    pub fn new() -> Self {
        Default::default()
    }

    /// Replace the form state.
    #[must_use]
    pub fn with_form_state(self, form_state: FormState<DeviceLinkFormField>) -> Self {
        // The form state is the only field, so nothing of `self` is carried over
        Self { form_state }
    }
}
impl TemplateContext for DeviceLinkContext {
    /// Two samples: a blank form, then the form with a "required" error on
    /// the code field.
    fn sample(_now: chrono::DateTime<Utc>, _rng: &mut impl Rng) -> Vec<Self>
    where
        Self: Sized,
    {
        let missing_code = FormState::default()
            .with_error_on_field(DeviceLinkFormField::Code, FieldError::Required);

        vec![Self::new(), Self::new().with_form_state(missing_code)]
    }
}
/// Context holding a device code grant and its client.
#[derive(Serialize, Debug)]
pub struct DeviceConsentContext {
/// The device code grant.
grant: DeviceCodeGrant,
/// The client.
client: Client,
}
impl DeviceConsentContext {
#[must_use]
pub fn new(grant: DeviceCodeGrant, client: Client) -> Self {
Self { grant, client }
}
}
impl TemplateContext for DeviceConsentContext {
    /// One sample per sample client, each with a pending device code grant
    /// created five minutes before `now` and expiring 25 minutes after it.
    fn sample(now: chrono::DateTime<Utc>, rng: &mut impl Rng) -> Vec<Self>
    where
        Self: Sized,
    {
        let clients = Client::samples(now, rng);
        let mut contexts = Vec::with_capacity(clients.len());
        for client in clients {
            // Fields are listed in the order in which they draw from `rng`
            // (id, user code, device code)
            let grant = DeviceCodeGrant {
                id: Ulid::from_datetime_with_source(now.into(), rng),
                state: mas_data_model::DeviceCodeGrantState::Pending,
                client_id: client.id,
                scope: [OPENID].into_iter().collect(),
                user_code: Alphanumeric.sample_string(rng, 6).to_uppercase(),
                device_code: Alphanumeric.sample_string(rng, 32),
                created_at: now - Duration::try_minutes(5).unwrap(),
                expires_at: now + Duration::try_minutes(25).unwrap(),
                ip_address: Some(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1))),
                user_agent: Some(UserAgent::parse("Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/93.0.0.0 Safari/537.36".to_owned())),
            };
            contexts.push(Self { grant, client });
        }
        contexts
    }
}
/// Context holding a redirect URI and a set of parameters.
#[derive(Serialize)]
pub struct FormPostContext<T> {
/// The redirect URI.
redirect_uri: Url,
/// The parameters.
params: T,
}
impl<T: TemplateContext> TemplateContext for FormPostContext<T> {
    /// One sample per sample of the parameters, each with a fixed example
    /// callback URI.
    fn sample(now: chrono::DateTime<Utc>, rng: &mut impl Rng) -> Vec<Self>
    where
        Self: Sized,
    {
        let sample_params = T::sample(now, rng);
        let mut contexts = Vec::with_capacity(sample_params.len());
        for params in sample_params {
            let redirect_uri = "https://example.com/callback".parse().unwrap();
            contexts.push(FormPostContext {
                redirect_uri,
                params,
            });
        }
        contexts
    }
}
impl<T> FormPostContext<T> {
pub fn new(redirect_uri: Url, params: T) -> Self {
Self {
redirect_uri,
params,
}
}
}
/// Context describing an error: an optional code, description, details and
/// language, all absent by default.
#[derive(Default, Serialize, Debug, Clone)]
pub struct ErrorContext {
/// Error code, if any.
code: Option<&'static str>,
/// Error description, if any.
description: Option<String>,
/// Error details, if any.
details: Option<String>,
/// Language, as a string, if any.
lang: Option<String>,
}
impl std::fmt::Display for ErrorContext {
    /// Write the code, description and details, one per line, skipping those
    /// that are absent. The language is never written.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let Self {
            code,
            description,
            details,
            ..
        } = self;

        if let Some(code) = code {
            writeln!(f, "code: {code}")?;
        }

        if let Some(description) = description {
            writeln!(f, "{description}")?;
        }

        if let Some(details) = details {
            writeln!(f, "details: {details}")?;
        }

        Ok(())
    }
}
impl TemplateContext for ErrorContext {
    /// Three samples: one with code, description and details, one with only a
    /// code, and an empty one.
    fn sample(_now: chrono::DateTime<Utc>, _rng: &mut impl Rng) -> Vec<Self>
    where
        Self: Sized,
    {
        let detailed = Self::new()
            .with_code("sample_error")
            .with_description("A fancy description".into())
            .with_details("Something happened".into());
        let code_only = Self::new().with_code("another_error");
        let empty = Self::new();

        vec![detailed, code_only, empty]
    }
}
impl ErrorContext {
    /// Build an error context with every field absent.
    #[must_use]
    pub fn new() -> Self {
        Default::default()
    }

    /// Set the error code.
    #[must_use]
    pub fn with_code(self, code: &'static str) -> Self {
        Self {
            code: Some(code),
            ..self
        }
    }

    /// Set the error description.
    #[must_use]
    pub fn with_description(self, description: String) -> Self {
        Self {
            description: Some(description),
            ..self
        }
    }

    /// Set the error details.
    #[must_use]
    pub fn with_details(self, details: String) -> Self {
        Self {
            details: Some(details),
            ..self
        }
    }

    /// Set the language from the given locale, stored as a string.
    #[must_use]
    pub fn with_language(self, lang: &DataLocale) -> Self {
        Self {
            lang: Some(lang.to_string()),
            ..self
        }
    }

    /// The error code, if any.
    #[must_use]
    pub fn code(&self) -> Option<&'static str> {
        self.code
    }

    /// The error description, if any.
    #[must_use]
    pub fn description(&self) -> Option<&str> {
        self.description.as_deref()
    }

    /// The error details, if any.
    #[must_use]
    pub fn details(&self) -> Option<&str> {
        self.details.as_deref()
    }
}
/// Context holding the method, HTTP version and URI of a request, all as
/// strings.
#[derive(Serialize)]
pub struct NotFoundContext {
/// The request method.
method: String,
/// The HTTP version, in its `Debug` representation.
version: String,
/// The request URI.
uri: String,
}
impl NotFoundContext {
    /// Build the context from a request method, HTTP version and URI.
    ///
    /// The method and URI are stored in their `Display` form, the version in
    /// its `Debug` form.
    #[must_use]
    pub fn new(method: &Method, version: Version, uri: &Uri) -> Self {
        let method = method.to_string();
        let version = format!("{version:?}");
        let uri = uri.to_string();
        Self {
            method,
            version,
            uri,
        }
    }
}
impl TemplateContext for NotFoundContext {
    /// Three samples covering different methods, HTTP versions and URIs.
    fn sample(_now: DateTime<Utc>, _rng: &mut impl Rng) -> Vec<Self>
    where
        Self: Sized,
    {
        let get_root = Self::new(&Method::GET, Version::HTTP_11, &"/".parse().unwrap());
        let post_path = Self::new(&Method::POST, Version::HTTP_2, &"/foo/bar".parse().unwrap());
        let put_query = Self::new(
            &Method::PUT,
            Version::HTTP_10,
            &"/foo?bar=baz".parse().unwrap(),
        );

        vec![get_root, post_path, put_query]
    }
}