mas_handlers/
lib.rs

1// Copyright 2024, 2025 New Vector Ltd.
2// Copyright 2021-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
5// Please see LICENSE files in the repository root for full details.
6
// Axum handler futures must be `Send`; turn any non-`Send` future into a hard error.
#![deny(clippy::future_not_send)]
#![allow(
    // Some axum handlers are `async` without awaiting anything (e.g. `fallback`),
    // because axum expects handlers to return futures
    clippy::unused_async,
    // Because of how axum handlers work (one extractor per argument), we sometimes
    // have to take many arguments
    clippy::too_many_arguments,
    // Code generated by `tracing::instrument` triggers this when returning an `impl Trait`
    // See https://github.com/tokio-rs/tracing/issues/2613
    clippy::let_with_type_underscore,
)]
17
18use std::{
19    convert::Infallible,
20    sync::{Arc, LazyLock},
21    time::Duration,
22};
23
24use axum::{
25    Extension, Router,
26    extract::{FromRef, FromRequestParts, OriginalUri, RawQuery, State},
27    http::Method,
28    response::{Html, IntoResponse},
29    routing::{get, post},
30};
31use headers::HeaderName;
32use hyper::{
33    StatusCode, Version,
34    header::{
35        ACCEPT, ACCEPT_LANGUAGE, AUTHORIZATION, CONTENT_LANGUAGE, CONTENT_LENGTH, CONTENT_TYPE,
36    },
37};
38use mas_axum_utils::{InternalError, cookies::CookieJar};
39use mas_data_model::SiteConfig;
40use mas_http::CorsLayerExt;
41use mas_keystore::{Encrypter, Keystore};
42use mas_matrix::HomeserverConnection;
43use mas_policy::Policy;
44use mas_router::{Route, UrlBuilder};
45use mas_storage::{BoxRepository, BoxRepositoryFactory};
46use mas_templates::{ErrorContext, NotFoundContext, TemplateContext, Templates};
47use opentelemetry::metrics::Meter;
48use sqlx::PgPool;
49use tower::util::AndThenLayer;
50use tower_http::cors::{Any, CorsLayer};
51
52use self::{graphql::ExtraRouterParameters, passwords::PasswordManager};
53
54mod admin;
55mod compat;
56mod graphql;
57mod health;
58mod oauth2;
59pub mod passwords;
60pub mod upstream_oauth2;
61mod views;
62
63mod activity_tracker;
64mod captcha;
65#[cfg(test)]
66mod cleanup_tests;
67mod preferred_language;
68mod rate_limit;
69mod session;
70#[cfg(test)]
71mod test_utils;
72
73static METER: LazyLock<Meter> = LazyLock::new(|| {
74    let scope = opentelemetry::InstrumentationScope::builder(env!("CARGO_PKG_NAME"))
75        .with_version(env!("CARGO_PKG_VERSION"))
76        .with_schema_url(opentelemetry_semantic_conventions::SCHEMA_URL)
77        .build();
78
79    opentelemetry::global::meter_with_scope(scope)
80});
81
/// Implement `From<E>` for `RouteError`, for "internal server error" kind of
/// errors.
///
/// The converted error is boxed into the `Internal` variant of the target type.
///
/// Two forms are accepted:
///
/// - `impl_from_error_for_route!(MyRouteError: SomeError)` implements
///   `From<SomeError>` for `MyRouteError`;
/// - `impl_from_error_for_route!(SomeError)` targets the `RouteError` type
///   found in the module where the macro is invoked.
///
/// The short form forwards to the long form through `$crate::`. It therefore
/// also works when the caller invokes this macro by path (for example
/// `crate::impl_from_error_for_route!(SomeError)`) without importing it.
#[macro_export]
macro_rules! impl_from_error_for_route {
    ($route_error:ty : $error:ty) => {
        impl ::std::convert::From<$error> for $route_error {
            fn from(e: $error) -> Self {
                Self::Internal(::std::boxed::Box::new(e))
            }
        }
    };
    ($error:ty) => {
        // `$crate::` makes the recursive call independent of the caller's imports
        $crate::impl_from_error_for_route!(self::RouteError: $error);
    };
}
97
98pub use mas_axum_utils::{ErrorWrapper, cookies::CookieManager};
99use mas_data_model::{BoxClock, BoxRng};
100
101pub use self::{
102    activity_tracker::{ActivityTracker, Bound as BoundActivityTracker},
103    admin::router as admin_api_router,
104    graphql::{
105        Schema as GraphQLSchema, schema as graphql_schema, schema_builder as graphql_schema_builder,
106    },
107    preferred_language::PreferredLanguage,
108    rate_limit::{Limiter, RequesterFingerprint},
109    upstream_oauth2::cache::MetadataCache,
110};
111
112pub fn healthcheck_router<S>() -> Router<S>
113where
114    S: Clone + Send + Sync + 'static,
115    PgPool: FromRef<S>,
116{
117    Router::new().route(mas_router::Healthcheck::route(), get(self::health::get))
118}
119
120pub fn graphql_router<S>(playground: bool, undocumented_oauth2_access: bool) -> Router<S>
121where
122    S: Clone + Send + Sync + 'static,
123    graphql::Schema: FromRef<S>,
124    BoundActivityTracker: FromRequestParts<S>,
125    BoxRepository: FromRequestParts<S>,
126    BoxClock: FromRequestParts<S>,
127    Encrypter: FromRef<S>,
128    CookieJar: FromRequestParts<S>,
129    Limiter: FromRef<S>,
130    RequesterFingerprint: FromRequestParts<S>,
131{
132    let mut router = Router::new()
133        .route(
134            mas_router::GraphQL::route(),
135            get(self::graphql::get).post(self::graphql::post),
136        )
137        // Pass the undocumented_oauth2_access parameter through the request extension, as it is
138        // per-listener
139        .layer(Extension(ExtraRouterParameters {
140            undocumented_oauth2_access,
141        }))
142        .layer(
143            CorsLayer::new()
144                .allow_origin(Any)
145                .allow_methods(Any)
146                .allow_otel_headers([
147                    AUTHORIZATION,
148                    ACCEPT,
149                    ACCEPT_LANGUAGE,
150                    CONTENT_LANGUAGE,
151                    CONTENT_TYPE,
152                ]),
153        );
154
155    if playground {
156        router = router.route(
157            mas_router::GraphQLPlayground::route(),
158            get(self::graphql::playground),
159        );
160    }
161
162    router
163}
164
165pub fn discovery_router<S>() -> Router<S>
166where
167    S: Clone + Send + Sync + 'static,
168    Keystore: FromRef<S>,
169    SiteConfig: FromRef<S>,
170    UrlBuilder: FromRef<S>,
171    BoxClock: FromRequestParts<S>,
172    BoxRng: FromRequestParts<S>,
173{
174    Router::new()
175        .route(
176            mas_router::OidcConfiguration::route(),
177            get(self::oauth2::discovery::get),
178        )
179        .route(
180            mas_router::Webfinger::route(),
181            get(self::oauth2::webfinger::get),
182        )
183        .layer(
184            CorsLayer::new()
185                .allow_origin(Any)
186                .allow_methods(Any)
187                .allow_otel_headers([
188                    AUTHORIZATION,
189                    ACCEPT,
190                    ACCEPT_LANGUAGE,
191                    CONTENT_LANGUAGE,
192                    CONTENT_TYPE,
193                ])
194                .max_age(Duration::from_secs(60 * 60)),
195        )
196}
197
198pub fn api_router<S>() -> Router<S>
199where
200    S: Clone + Send + Sync + 'static,
201    Keystore: FromRef<S>,
202    UrlBuilder: FromRef<S>,
203    BoxRepository: FromRequestParts<S>,
204    ActivityTracker: FromRequestParts<S>,
205    BoundActivityTracker: FromRequestParts<S>,
206    Encrypter: FromRef<S>,
207    reqwest::Client: FromRef<S>,
208    SiteConfig: FromRef<S>,
209    Templates: FromRef<S>,
210    Arc<dyn HomeserverConnection>: FromRef<S>,
211    BoxClock: FromRequestParts<S>,
212    BoxRng: FromRequestParts<S>,
213    Policy: FromRequestParts<S>,
214{
215    // All those routes are API-like, with a common CORS layer
216    Router::new()
217        .route(
218            mas_router::OAuth2Keys::route(),
219            get(self::oauth2::keys::get),
220        )
221        .route(
222            mas_router::OidcUserinfo::route(),
223            get(self::oauth2::userinfo::get).post(self::oauth2::userinfo::get),
224        )
225        .route(
226            mas_router::OAuth2Introspection::route(),
227            post(self::oauth2::introspection::post),
228        )
229        .route(
230            mas_router::OAuth2Revocation::route(),
231            post(self::oauth2::revoke::post),
232        )
233        .route(
234            mas_router::OAuth2TokenEndpoint::route(),
235            post(self::oauth2::token::post),
236        )
237        .route(
238            mas_router::OAuth2RegistrationEndpoint::route(),
239            post(self::oauth2::registration::post),
240        )
241        .route(
242            mas_router::OAuth2DeviceAuthorizationEndpoint::route(),
243            post(self::oauth2::device::authorize::post),
244        )
245        .layer(
246            CorsLayer::new()
247                .allow_origin(Any)
248                .allow_methods(Any)
249                .allow_otel_headers([
250                    AUTHORIZATION,
251                    ACCEPT,
252                    ACCEPT_LANGUAGE,
253                    CONTENT_LANGUAGE,
254                    CONTENT_TYPE,
255                    // Swagger will send this header, so we have to allow it to avoid CORS errors
256                    HeaderName::from_static("x-requested-with"),
257                ])
258                .max_age(Duration::from_secs(60 * 60)),
259        )
260}
261
262#[allow(clippy::trait_duplication_in_bounds)]
263pub fn compat_router<S>(templates: Templates) -> Router<S>
264where
265    S: Clone + Send + Sync + 'static,
266    UrlBuilder: FromRef<S>,
267    SiteConfig: FromRef<S>,
268    Arc<dyn HomeserverConnection>: FromRef<S>,
269    PasswordManager: FromRef<S>,
270    Limiter: FromRef<S>,
271    BoxRepositoryFactory: FromRef<S>,
272    BoundActivityTracker: FromRequestParts<S>,
273    RequesterFingerprint: FromRequestParts<S>,
274    BoxRepository: FromRequestParts<S>,
275    BoxClock: FromRequestParts<S>,
276    BoxRng: FromRequestParts<S>,
277    Policy: FromRequestParts<S>,
278{
279    // A sub-router for human-facing routes with error handling
280    let human_router = Router::new()
281        .route(
282            mas_router::CompatLoginSsoRedirect::route(),
283            get(self::compat::login_sso_redirect::get),
284        )
285        .route(
286            mas_router::CompatLoginSsoRedirectIdp::route(),
287            get(self::compat::login_sso_redirect::get),
288        )
289        .route(
290            mas_router::CompatLoginSsoRedirectSlash::route(),
291            get(self::compat::login_sso_redirect::get),
292        )
293        .layer(AndThenLayer::new(
294            async move |response: axum::response::Response| {
295                Ok::<_, Infallible>(recover_error(&templates, response))
296            },
297        ));
298
299    // A sub-router for API-facing routes with CORS
300    let api_router = Router::new()
301        .route(
302            mas_router::CompatLogin::route(),
303            get(self::compat::login::get).post(self::compat::login::post),
304        )
305        .route(
306            mas_router::CompatLogout::route(),
307            post(self::compat::logout::post),
308        )
309        .route(
310            mas_router::CompatLogoutAll::route(),
311            post(self::compat::logout_all::post),
312        )
313        .route(
314            mas_router::CompatRefresh::route(),
315            post(self::compat::refresh::post),
316        )
317        .layer(
318            CorsLayer::new()
319                .allow_origin(Any)
320                .allow_methods(Any)
321                .allow_otel_headers([
322                    AUTHORIZATION,
323                    ACCEPT,
324                    ACCEPT_LANGUAGE,
325                    CONTENT_LANGUAGE,
326                    CONTENT_TYPE,
327                    HeaderName::from_static("x-requested-with"),
328                ])
329                .max_age(Duration::from_secs(60 * 60)),
330        );
331
332    Router::new().merge(human_router).merge(api_router)
333}
334
335pub fn human_router<S>(templates: Templates) -> Router<S>
336where
337    S: Clone + Send + Sync + 'static,
338    UrlBuilder: FromRef<S>,
339    PreferredLanguage: FromRequestParts<S>,
340    BoxRepository: FromRequestParts<S>,
341    CookieJar: FromRequestParts<S>,
342    BoundActivityTracker: FromRequestParts<S>,
343    RequesterFingerprint: FromRequestParts<S>,
344    Encrypter: FromRef<S>,
345    Templates: FromRef<S>,
346    Keystore: FromRef<S>,
347    PasswordManager: FromRef<S>,
348    MetadataCache: FromRef<S>,
349    SiteConfig: FromRef<S>,
350    Limiter: FromRef<S>,
351    reqwest::Client: FromRef<S>,
352    Arc<dyn HomeserverConnection>: FromRef<S>,
353    BoxClock: FromRequestParts<S>,
354    BoxRng: FromRequestParts<S>,
355    Policy: FromRequestParts<S>,
356{
357    Router::new()
358        // XXX: hard-coded redirect from /account to /account/
359        .route(
360            "/account",
361            get(
362                async |State(url_builder): State<UrlBuilder>, RawQuery(query): RawQuery| {
363                    let prefix = url_builder.prefix().unwrap_or_default();
364                    let route = mas_router::Account::route();
365                    let destination = if let Some(query) = query {
366                        format!("{prefix}{route}?{query}")
367                    } else {
368                        format!("{prefix}{route}")
369                    };
370
371                    axum::response::Redirect::to(&destination)
372                },
373            ),
374        )
375        .route(mas_router::Account::route(), get(self::views::app::get))
376        .route(
377            mas_router::AccountWildcard::route(),
378            get(self::views::app::get),
379        )
380        .route(
381            mas_router::AccountRecoveryFinish::route(),
382            get(self::views::app::get_anonymous),
383        )
384        .route(
385            mas_router::ChangePasswordDiscovery::route(),
386            get(async |State(url_builder): State<UrlBuilder>| {
387                url_builder.redirect(&mas_router::AccountPasswordChange)
388            }),
389        )
390        .route(mas_router::Index::route(), get(self::views::index::get))
391        .route(
392            mas_router::Login::route(),
393            get(self::views::login::get).post(self::views::login::post),
394        )
395        .route(mas_router::Logout::route(), post(self::views::logout::post))
396        .route(
397            mas_router::Register::route(),
398            get(self::views::register::get),
399        )
400        .route(
401            mas_router::PasswordRegister::route(),
402            get(self::views::register::password::get).post(self::views::register::password::post),
403        )
404        .route(
405            mas_router::RegisterVerifyEmail::route(),
406            get(self::views::register::steps::verify_email::get)
407                .post(self::views::register::steps::verify_email::post),
408        )
409        .route(
410            mas_router::RegisterToken::route(),
411            get(self::views::register::steps::registration_token::get)
412                .post(self::views::register::steps::registration_token::post),
413        )
414        .route(
415            mas_router::RegisterDisplayName::route(),
416            get(self::views::register::steps::display_name::get)
417                .post(self::views::register::steps::display_name::post),
418        )
419        .route(
420            mas_router::RegisterFinish::route(),
421            get(self::views::register::steps::finish::get),
422        )
423        .route(
424            mas_router::AccountRecoveryStart::route(),
425            get(self::views::recovery::start::get).post(self::views::recovery::start::post),
426        )
427        .route(
428            mas_router::AccountRecoveryProgress::route(),
429            get(self::views::recovery::progress::get).post(self::views::recovery::progress::post),
430        )
431        .route(
432            mas_router::OAuth2AuthorizationEndpoint::route(),
433            get(self::oauth2::authorization::get),
434        )
435        .route(
436            mas_router::Consent::route(),
437            get(self::oauth2::authorization::consent::get)
438                .post(self::oauth2::authorization::consent::post),
439        )
440        .route(
441            mas_router::CompatLoginSsoComplete::route(),
442            get(self::compat::login_sso_complete::get).post(self::compat::login_sso_complete::post),
443        )
444        .route(
445            mas_router::UpstreamOAuth2Authorize::route(),
446            get(self::upstream_oauth2::authorize::get),
447        )
448        .route(
449            mas_router::UpstreamOAuth2Callback::route(),
450            get(self::upstream_oauth2::callback::handler)
451                .post(self::upstream_oauth2::callback::handler),
452        )
453        .route(
454            mas_router::UpstreamOAuth2Link::route(),
455            get(self::upstream_oauth2::link::get).post(self::upstream_oauth2::link::post),
456        )
457        .route(
458            mas_router::UpstreamOAuth2BackchannelLogout::route(),
459            post(self::upstream_oauth2::backchannel_logout::post),
460        )
461        .route(
462            mas_router::DeviceCodeLink::route(),
463            get(self::oauth2::device::link::get),
464        )
465        .route(
466            mas_router::DeviceCodeConsent::route(),
467            get(self::oauth2::device::consent::get).post(self::oauth2::device::consent::post),
468        )
469        .layer(AndThenLayer::new(
470            async move |response: axum::response::Response| {
471                Ok::<_, Infallible>(recover_error(&templates, response))
472            },
473        ))
474}
475
476fn recover_error(
477    templates: &Templates,
478    response: axum::response::Response,
479) -> axum::response::Response {
480    // Error responses should have an ErrorContext attached to them
481    let ext = response.extensions().get::<ErrorContext>();
482    if let Some(ctx) = ext
483        && let Ok(res) = templates.render_error(ctx)
484    {
485        let (mut parts, _original_body) = response.into_parts();
486        parts.headers.remove(CONTENT_TYPE);
487        parts.headers.remove(CONTENT_LENGTH);
488        return (parts, Html(res)).into_response();
489    }
490
491    response
492}
493
494/// The fallback handler for all routes that don't match anything else.
495///
496/// # Errors
497///
498/// Returns an error if the template rendering fails.
499pub async fn fallback(
500    State(templates): State<Templates>,
501    OriginalUri(uri): OriginalUri,
502    method: Method,
503    version: Version,
504    PreferredLanguage(locale): PreferredLanguage,
505) -> Result<impl IntoResponse, InternalError> {
506    let ctx = NotFoundContext::new(&method, version, &uri).with_language(locale);
507    // XXX: this should look at the Accept header and return JSON if requested
508
509    let res = templates.render_not_found(&ctx)?;
510
511    Ok((StatusCode::NOT_FOUND, Html(res)))
512}