mas_handlers/
lib.rs

1// Copyright 2024, 2025 New Vector Ltd.
2// Copyright 2021-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
5// Please see LICENSE files in the repository root for full details.
6
7#![deny(clippy::future_not_send)]
8#![allow(
9    // Some axum handlers need that
10    clippy::unused_async,
11    // Because of how axum handlers work, we sometime have take many arguments
12    clippy::too_many_arguments,
13    // Code generated by tracing::instrument trigger this when returning an `impl Trait`
14    // See https://github.com/tokio-rs/tracing/issues/2613
15    clippy::let_with_type_underscore,
16)]
17
18use std::{
19    convert::Infallible,
20    sync::{Arc, LazyLock},
21    time::Duration,
22};
23
24use axum::{
25    Extension, Router,
26    extract::{FromRef, FromRequestParts, OriginalUri, RawQuery, State},
27    http::Method,
28    response::{Html, IntoResponse},
29    routing::{get, post},
30};
31use headers::HeaderName;
32use hyper::{
33    StatusCode, Version,
34    header::{
35        ACCEPT, ACCEPT_LANGUAGE, AUTHORIZATION, CONTENT_LANGUAGE, CONTENT_LENGTH, CONTENT_TYPE,
36    },
37};
38use mas_axum_utils::{InternalError, cookies::CookieJar};
39use mas_data_model::SiteConfig;
40use mas_http::CorsLayerExt;
41use mas_keystore::{Encrypter, Keystore};
42use mas_matrix::HomeserverConnection;
43use mas_policy::Policy;
44use mas_router::{Route, UrlBuilder};
45use mas_storage::{BoxClock, BoxRepository, BoxRepositoryFactory, BoxRng};
46use mas_templates::{ErrorContext, NotFoundContext, TemplateContext, Templates};
47use opentelemetry::metrics::Meter;
48use sqlx::PgPool;
49use tower::util::AndThenLayer;
50use tower_http::cors::{Any, CorsLayer};
51
52use self::{graphql::ExtraRouterParameters, passwords::PasswordManager};
53
54mod admin;
55mod compat;
56mod graphql;
57mod health;
58mod oauth2;
59pub mod passwords;
60pub mod upstream_oauth2;
61mod views;
62
63mod activity_tracker;
64mod captcha;
65mod preferred_language;
66mod rate_limit;
67mod session;
68#[cfg(test)]
69mod test_utils;
70
71static METER: LazyLock<Meter> = LazyLock::new(|| {
72    let scope = opentelemetry::InstrumentationScope::builder(env!("CARGO_PKG_NAME"))
73        .with_version(env!("CARGO_PKG_VERSION"))
74        .with_schema_url(opentelemetry_semantic_conventions::SCHEMA_URL)
75        .build();
76
77    opentelemetry::global::meter_with_scope(scope)
78});
79
/// Implement `From<E>` for `RouteError`, for "internal server error" kind of
/// errors.
///
/// Two forms are accepted:
///
/// - `impl_from_error_for_route!(MyRouteError: SomeError)` implements
///   `From<SomeError> for MyRouteError`;
/// - `impl_from_error_for_route!(SomeError)` is shorthand for the form above
///   targeting `self::RouteError`, i.e. the `RouteError` type of the module
///   where the macro is invoked.
///
/// In both cases the converted error is wrapped as
/// `Self::Internal(Box::new(e))`, so the target type needs an `Internal`
/// variant that accepts that box (presumably a boxed `dyn Error` — confirm
/// against each `RouteError`).
#[macro_export]
macro_rules! impl_from_error_for_route {
    ($route_error:ty : $error:ty) => {
        impl From<$error> for $route_error {
            fn from(e: $error) -> Self {
                Self::Internal(Box::new(e))
            }
        }
    };
    ($error:ty) => {
        // Recurse through `$crate` rather than the bare macro name: the bare
        // name is resolved at the call site, which breaks when the caller
        // invoked this macro by path without importing it by name.
        $crate::impl_from_error_for_route!(self::RouteError: $error);
    };
}
95
96pub use mas_axum_utils::{ErrorWrapper, cookies::CookieManager};
97
98pub use self::{
99    activity_tracker::{ActivityTracker, Bound as BoundActivityTracker},
100    admin::router as admin_api_router,
101    graphql::{
102        Schema as GraphQLSchema, schema as graphql_schema, schema_builder as graphql_schema_builder,
103    },
104    preferred_language::PreferredLanguage,
105    rate_limit::{Limiter, RequesterFingerprint},
106    upstream_oauth2::cache::MetadataCache,
107};
108
109pub fn healthcheck_router<S>() -> Router<S>
110where
111    S: Clone + Send + Sync + 'static,
112    PgPool: FromRef<S>,
113{
114    Router::new().route(mas_router::Healthcheck::route(), get(self::health::get))
115}
116
117pub fn graphql_router<S>(playground: bool, undocumented_oauth2_access: bool) -> Router<S>
118where
119    S: Clone + Send + Sync + 'static,
120    graphql::Schema: FromRef<S>,
121    BoundActivityTracker: FromRequestParts<S>,
122    BoxRepository: FromRequestParts<S>,
123    BoxClock: FromRequestParts<S>,
124    Encrypter: FromRef<S>,
125    CookieJar: FromRequestParts<S>,
126    Limiter: FromRef<S>,
127    RequesterFingerprint: FromRequestParts<S>,
128{
129    let mut router = Router::new()
130        .route(
131            mas_router::GraphQL::route(),
132            get(self::graphql::get).post(self::graphql::post),
133        )
134        // Pass the undocumented_oauth2_access parameter through the request extension, as it is
135        // per-listener
136        .layer(Extension(ExtraRouterParameters {
137            undocumented_oauth2_access,
138        }))
139        .layer(
140            CorsLayer::new()
141                .allow_origin(Any)
142                .allow_methods(Any)
143                .allow_otel_headers([
144                    AUTHORIZATION,
145                    ACCEPT,
146                    ACCEPT_LANGUAGE,
147                    CONTENT_LANGUAGE,
148                    CONTENT_TYPE,
149                ]),
150        );
151
152    if playground {
153        router = router.route(
154            mas_router::GraphQLPlayground::route(),
155            get(self::graphql::playground),
156        );
157    }
158
159    router
160}
161
162pub fn discovery_router<S>() -> Router<S>
163where
164    S: Clone + Send + Sync + 'static,
165    Keystore: FromRef<S>,
166    SiteConfig: FromRef<S>,
167    UrlBuilder: FromRef<S>,
168    BoxClock: FromRequestParts<S>,
169    BoxRng: FromRequestParts<S>,
170{
171    Router::new()
172        .route(
173            mas_router::OidcConfiguration::route(),
174            get(self::oauth2::discovery::get),
175        )
176        .route(
177            mas_router::Webfinger::route(),
178            get(self::oauth2::webfinger::get),
179        )
180        .layer(
181            CorsLayer::new()
182                .allow_origin(Any)
183                .allow_methods(Any)
184                .allow_otel_headers([
185                    AUTHORIZATION,
186                    ACCEPT,
187                    ACCEPT_LANGUAGE,
188                    CONTENT_LANGUAGE,
189                    CONTENT_TYPE,
190                ])
191                .max_age(Duration::from_secs(60 * 60)),
192        )
193}
194
195pub fn api_router<S>() -> Router<S>
196where
197    S: Clone + Send + Sync + 'static,
198    Keystore: FromRef<S>,
199    UrlBuilder: FromRef<S>,
200    BoxRepository: FromRequestParts<S>,
201    ActivityTracker: FromRequestParts<S>,
202    BoundActivityTracker: FromRequestParts<S>,
203    Encrypter: FromRef<S>,
204    reqwest::Client: FromRef<S>,
205    SiteConfig: FromRef<S>,
206    Templates: FromRef<S>,
207    Arc<dyn HomeserverConnection>: FromRef<S>,
208    BoxClock: FromRequestParts<S>,
209    BoxRng: FromRequestParts<S>,
210    Policy: FromRequestParts<S>,
211{
212    // All those routes are API-like, with a common CORS layer
213    Router::new()
214        .route(
215            mas_router::OAuth2Keys::route(),
216            get(self::oauth2::keys::get),
217        )
218        .route(
219            mas_router::OidcUserinfo::route(),
220            get(self::oauth2::userinfo::get).post(self::oauth2::userinfo::get),
221        )
222        .route(
223            mas_router::OAuth2Introspection::route(),
224            post(self::oauth2::introspection::post),
225        )
226        .route(
227            mas_router::OAuth2Revocation::route(),
228            post(self::oauth2::revoke::post),
229        )
230        .route(
231            mas_router::OAuth2TokenEndpoint::route(),
232            post(self::oauth2::token::post),
233        )
234        .route(
235            mas_router::OAuth2RegistrationEndpoint::route(),
236            post(self::oauth2::registration::post),
237        )
238        .route(
239            mas_router::OAuth2DeviceAuthorizationEndpoint::route(),
240            post(self::oauth2::device::authorize::post),
241        )
242        .layer(
243            CorsLayer::new()
244                .allow_origin(Any)
245                .allow_methods(Any)
246                .allow_otel_headers([
247                    AUTHORIZATION,
248                    ACCEPT,
249                    ACCEPT_LANGUAGE,
250                    CONTENT_LANGUAGE,
251                    CONTENT_TYPE,
252                    // Swagger will send this header, so we have to allow it to avoid CORS errors
253                    HeaderName::from_static("x-requested-with"),
254                ])
255                .max_age(Duration::from_secs(60 * 60)),
256        )
257}
258
259#[allow(clippy::trait_duplication_in_bounds)]
260pub fn compat_router<S>(templates: Templates) -> Router<S>
261where
262    S: Clone + Send + Sync + 'static,
263    UrlBuilder: FromRef<S>,
264    SiteConfig: FromRef<S>,
265    Arc<dyn HomeserverConnection>: FromRef<S>,
266    PasswordManager: FromRef<S>,
267    Limiter: FromRef<S>,
268    BoxRepositoryFactory: FromRef<S>,
269    BoundActivityTracker: FromRequestParts<S>,
270    RequesterFingerprint: FromRequestParts<S>,
271    BoxRepository: FromRequestParts<S>,
272    BoxClock: FromRequestParts<S>,
273    BoxRng: FromRequestParts<S>,
274{
275    // A sub-router for human-facing routes with error handling
276    let human_router = Router::new()
277        .route(
278            mas_router::CompatLoginSsoRedirect::route(),
279            get(self::compat::login_sso_redirect::get),
280        )
281        .route(
282            mas_router::CompatLoginSsoRedirectIdp::route(),
283            get(self::compat::login_sso_redirect::get),
284        )
285        .route(
286            mas_router::CompatLoginSsoRedirectSlash::route(),
287            get(self::compat::login_sso_redirect::get),
288        )
289        .layer(AndThenLayer::new(
290            async move |response: axum::response::Response| {
291                Ok::<_, Infallible>(recover_error(&templates, response))
292            },
293        ));
294
295    // A sub-router for API-facing routes with CORS
296    let api_router = Router::new()
297        .route(
298            mas_router::CompatLogin::route(),
299            get(self::compat::login::get).post(self::compat::login::post),
300        )
301        .route(
302            mas_router::CompatLogout::route(),
303            post(self::compat::logout::post),
304        )
305        .route(
306            mas_router::CompatLogoutAll::route(),
307            post(self::compat::logout_all::post),
308        )
309        .route(
310            mas_router::CompatRefresh::route(),
311            post(self::compat::refresh::post),
312        )
313        .layer(
314            CorsLayer::new()
315                .allow_origin(Any)
316                .allow_methods(Any)
317                .allow_otel_headers([
318                    AUTHORIZATION,
319                    ACCEPT,
320                    ACCEPT_LANGUAGE,
321                    CONTENT_LANGUAGE,
322                    CONTENT_TYPE,
323                    HeaderName::from_static("x-requested-with"),
324                ])
325                .max_age(Duration::from_secs(60 * 60)),
326        );
327
328    Router::new().merge(human_router).merge(api_router)
329}
330
331pub fn human_router<S>(templates: Templates) -> Router<S>
332where
333    S: Clone + Send + Sync + 'static,
334    UrlBuilder: FromRef<S>,
335    PreferredLanguage: FromRequestParts<S>,
336    BoxRepository: FromRequestParts<S>,
337    CookieJar: FromRequestParts<S>,
338    BoundActivityTracker: FromRequestParts<S>,
339    RequesterFingerprint: FromRequestParts<S>,
340    Encrypter: FromRef<S>,
341    Templates: FromRef<S>,
342    Keystore: FromRef<S>,
343    PasswordManager: FromRef<S>,
344    MetadataCache: FromRef<S>,
345    SiteConfig: FromRef<S>,
346    Limiter: FromRef<S>,
347    reqwest::Client: FromRef<S>,
348    Arc<dyn HomeserverConnection>: FromRef<S>,
349    BoxClock: FromRequestParts<S>,
350    BoxRng: FromRequestParts<S>,
351    Policy: FromRequestParts<S>,
352{
353    Router::new()
354        // XXX: hard-coded redirect from /account to /account/
355        .route(
356            "/account",
357            get(
358                async |State(url_builder): State<UrlBuilder>, RawQuery(query): RawQuery| {
359                    let prefix = url_builder.prefix().unwrap_or_default();
360                    let route = mas_router::Account::route();
361                    let destination = if let Some(query) = query {
362                        format!("{prefix}{route}?{query}")
363                    } else {
364                        format!("{prefix}{route}")
365                    };
366
367                    axum::response::Redirect::to(&destination)
368                },
369            ),
370        )
371        .route(mas_router::Account::route(), get(self::views::app::get))
372        .route(
373            mas_router::AccountWildcard::route(),
374            get(self::views::app::get),
375        )
376        .route(
377            mas_router::AccountRecoveryFinish::route(),
378            get(self::views::app::get_anonymous),
379        )
380        .route(
381            mas_router::ChangePasswordDiscovery::route(),
382            get(async |State(url_builder): State<UrlBuilder>| {
383                url_builder.redirect(&mas_router::AccountPasswordChange)
384            }),
385        )
386        .route(mas_router::Index::route(), get(self::views::index::get))
387        .route(
388            mas_router::Login::route(),
389            get(self::views::login::get).post(self::views::login::post),
390        )
391        .route(mas_router::Logout::route(), post(self::views::logout::post))
392        .route(
393            mas_router::Register::route(),
394            get(self::views::register::get),
395        )
396        .route(
397            mas_router::PasswordRegister::route(),
398            get(self::views::register::password::get).post(self::views::register::password::post),
399        )
400        .route(
401            mas_router::RegisterVerifyEmail::route(),
402            get(self::views::register::steps::verify_email::get)
403                .post(self::views::register::steps::verify_email::post),
404        )
405        .route(
406            mas_router::RegisterToken::route(),
407            get(self::views::register::steps::registration_token::get)
408                .post(self::views::register::steps::registration_token::post),
409        )
410        .route(
411            mas_router::RegisterDisplayName::route(),
412            get(self::views::register::steps::display_name::get)
413                .post(self::views::register::steps::display_name::post),
414        )
415        .route(
416            mas_router::RegisterFinish::route(),
417            get(self::views::register::steps::finish::get),
418        )
419        .route(
420            mas_router::AccountRecoveryStart::route(),
421            get(self::views::recovery::start::get).post(self::views::recovery::start::post),
422        )
423        .route(
424            mas_router::AccountRecoveryProgress::route(),
425            get(self::views::recovery::progress::get).post(self::views::recovery::progress::post),
426        )
427        .route(
428            mas_router::OAuth2AuthorizationEndpoint::route(),
429            get(self::oauth2::authorization::get),
430        )
431        .route(
432            mas_router::Consent::route(),
433            get(self::oauth2::authorization::consent::get)
434                .post(self::oauth2::authorization::consent::post),
435        )
436        .route(
437            mas_router::CompatLoginSsoComplete::route(),
438            get(self::compat::login_sso_complete::get).post(self::compat::login_sso_complete::post),
439        )
440        .route(
441            mas_router::UpstreamOAuth2Authorize::route(),
442            get(self::upstream_oauth2::authorize::get),
443        )
444        .route(
445            mas_router::UpstreamOAuth2Callback::route(),
446            get(self::upstream_oauth2::callback::handler)
447                .post(self::upstream_oauth2::callback::handler),
448        )
449        .route(
450            mas_router::UpstreamOAuth2Link::route(),
451            get(self::upstream_oauth2::link::get).post(self::upstream_oauth2::link::post),
452        )
453        .route(
454            mas_router::UpstreamOAuth2BackchannelLogout::route(),
455            post(self::upstream_oauth2::backchannel_logout::post),
456        )
457        .route(
458            mas_router::DeviceCodeLink::route(),
459            get(self::oauth2::device::link::get),
460        )
461        .route(
462            mas_router::DeviceCodeConsent::route(),
463            get(self::oauth2::device::consent::get).post(self::oauth2::device::consent::post),
464        )
465        .layer(AndThenLayer::new(
466            async move |response: axum::response::Response| {
467                Ok::<_, Infallible>(recover_error(&templates, response))
468            },
469        ))
470}
471
472fn recover_error(
473    templates: &Templates,
474    response: axum::response::Response,
475) -> axum::response::Response {
476    // Error responses should have an ErrorContext attached to them
477    let ext = response.extensions().get::<ErrorContext>();
478    if let Some(ctx) = ext {
479        if let Ok(res) = templates.render_error(ctx) {
480            let (mut parts, _original_body) = response.into_parts();
481            parts.headers.remove(CONTENT_TYPE);
482            parts.headers.remove(CONTENT_LENGTH);
483            return (parts, Html(res)).into_response();
484        }
485    }
486
487    response
488}
489
490/// The fallback handler for all routes that don't match anything else.
491///
492/// # Errors
493///
494/// Returns an error if the template rendering fails.
495pub async fn fallback(
496    State(templates): State<Templates>,
497    OriginalUri(uri): OriginalUri,
498    method: Method,
499    version: Version,
500    PreferredLanguage(locale): PreferredLanguage,
501) -> Result<impl IntoResponse, InternalError> {
502    let ctx = NotFoundContext::new(&method, version, &uri).with_language(locale);
503    // XXX: this should look at the Accept header and return JSON if requested
504
505    let res = templates.render_not_found(&ctx)?;
506
507    Ok((StatusCode::NOT_FOUND, Html(res)))
508}