1#![deny(clippy::future_not_send)]
8#![allow(
9 clippy::unused_async,
11 clippy::too_many_arguments,
13 clippy::let_with_type_underscore,
16)]
17
18use std::{
19 convert::Infallible,
20 sync::{Arc, LazyLock},
21 time::Duration,
22};
23
24use axum::{
25 Extension, Router,
26 extract::{FromRef, FromRequestParts, OriginalUri, RawQuery, State},
27 http::Method,
28 response::{Html, IntoResponse},
29 routing::{get, post},
30};
31use headers::HeaderName;
32use hyper::{
33 StatusCode, Version,
34 header::{
35 ACCEPT, ACCEPT_LANGUAGE, AUTHORIZATION, CONTENT_LANGUAGE, CONTENT_LENGTH, CONTENT_TYPE,
36 },
37};
38use mas_axum_utils::{InternalError, cookies::CookieJar};
39use mas_data_model::SiteConfig;
40use mas_http::CorsLayerExt;
41use mas_keystore::{Encrypter, Keystore};
42use mas_matrix::HomeserverConnection;
43use mas_policy::Policy;
44use mas_router::{Route, UrlBuilder};
45use mas_storage::{BoxRepository, BoxRepositoryFactory};
46use mas_templates::{ErrorContext, NotFoundContext, TemplateContext, Templates};
47use opentelemetry::metrics::Meter;
48use sqlx::PgPool;
49use tower::util::AndThenLayer;
50use tower_http::cors::{Any, CorsLayer};
51
52use self::{graphql::ExtraRouterParameters, passwords::PasswordManager};
53
54mod admin;
55mod compat;
56mod graphql;
57mod health;
58mod oauth2;
59pub mod passwords;
60pub mod upstream_oauth2;
61mod views;
62
63mod activity_tracker;
64mod captcha;
65#[cfg(test)]
66mod cleanup_tests;
67mod preferred_language;
68mod rate_limit;
69mod session;
70#[cfg(test)]
71mod test_utils;
72
73static METER: LazyLock<Meter> = LazyLock::new(|| {
74 let scope = opentelemetry::InstrumentationScope::builder(env!("CARGO_PKG_NAME"))
75 .with_version(env!("CARGO_PKG_VERSION"))
76 .with_schema_url(opentelemetry_semantic_conventions::SCHEMA_URL)
77 .build();
78
79 opentelemetry::global::meter_with_scope(scope)
80});
81
/// Implements `From<$error>` for a route error type by boxing the error into
/// the route error's `Internal` variant.
///
/// Two forms are accepted:
///
/// - `impl_from_error_for_route!(MyRouteError: SomeError)` implements the
///   conversion for the explicitly named route error type;
/// - `impl_from_error_for_route!(SomeError)` implements it for `RouteError` in
///   the module where the macro is invoked (`self::RouteError`).
///
/// The target type must have an `Internal` tuple variant that accepts the value
/// produced by `Box::new(e)`, for example a boxed error trait object.
#[macro_export]
macro_rules! impl_from_error_for_route {
    ($route_error:ty : $error:ty) => {
        impl From<$error> for $route_error {
            fn from(e: $error) -> Self {
                Self::Internal(Box::new(e))
            }
        }
    };
    ($error:ty) => {
        // Call through `$crate::` so the recursive invocation resolves even when
        // the invoking module has not brought the macro into scope by name.
        $crate::impl_from_error_for_route!(self::RouteError: $error);
    };
}
97
98pub use mas_axum_utils::{ErrorWrapper, cookies::CookieManager};
99use mas_data_model::{BoxClock, BoxRng};
100
101pub use self::{
102 activity_tracker::{ActivityTracker, Bound as BoundActivityTracker},
103 admin::router as admin_api_router,
104 graphql::{
105 Schema as GraphQLSchema, schema as graphql_schema, schema_builder as graphql_schema_builder,
106 },
107 preferred_language::PreferredLanguage,
108 rate_limit::{Limiter, RequesterFingerprint},
109 upstream_oauth2::cache::MetadataCache,
110};
111
112pub fn healthcheck_router<S>() -> Router<S>
113where
114 S: Clone + Send + Sync + 'static,
115 PgPool: FromRef<S>,
116{
117 Router::new().route(mas_router::Healthcheck::route(), get(self::health::get))
118}
119
120pub fn graphql_router<S>(playground: bool, undocumented_oauth2_access: bool) -> Router<S>
121where
122 S: Clone + Send + Sync + 'static,
123 graphql::Schema: FromRef<S>,
124 BoundActivityTracker: FromRequestParts<S>,
125 BoxRepository: FromRequestParts<S>,
126 BoxClock: FromRequestParts<S>,
127 Encrypter: FromRef<S>,
128 CookieJar: FromRequestParts<S>,
129 Limiter: FromRef<S>,
130 RequesterFingerprint: FromRequestParts<S>,
131{
132 let mut router = Router::new()
133 .route(
134 mas_router::GraphQL::route(),
135 get(self::graphql::get).post(self::graphql::post),
136 )
137 .layer(Extension(ExtraRouterParameters {
140 undocumented_oauth2_access,
141 }))
142 .layer(
143 CorsLayer::new()
144 .allow_origin(Any)
145 .allow_methods(Any)
146 .allow_otel_headers([
147 AUTHORIZATION,
148 ACCEPT,
149 ACCEPT_LANGUAGE,
150 CONTENT_LANGUAGE,
151 CONTENT_TYPE,
152 ]),
153 );
154
155 if playground {
156 router = router.route(
157 mas_router::GraphQLPlayground::route(),
158 get(self::graphql::playground),
159 );
160 }
161
162 router
163}
164
165pub fn discovery_router<S>() -> Router<S>
166where
167 S: Clone + Send + Sync + 'static,
168 Keystore: FromRef<S>,
169 SiteConfig: FromRef<S>,
170 UrlBuilder: FromRef<S>,
171 BoxClock: FromRequestParts<S>,
172 BoxRng: FromRequestParts<S>,
173{
174 Router::new()
175 .route(
176 mas_router::OidcConfiguration::route(),
177 get(self::oauth2::discovery::get),
178 )
179 .route(
180 mas_router::Webfinger::route(),
181 get(self::oauth2::webfinger::get),
182 )
183 .layer(
184 CorsLayer::new()
185 .allow_origin(Any)
186 .allow_methods(Any)
187 .allow_otel_headers([
188 AUTHORIZATION,
189 ACCEPT,
190 ACCEPT_LANGUAGE,
191 CONTENT_LANGUAGE,
192 CONTENT_TYPE,
193 ])
194 .max_age(Duration::from_secs(60 * 60)),
195 )
196}
197
198pub fn api_router<S>() -> Router<S>
199where
200 S: Clone + Send + Sync + 'static,
201 Keystore: FromRef<S>,
202 UrlBuilder: FromRef<S>,
203 BoxRepository: FromRequestParts<S>,
204 ActivityTracker: FromRequestParts<S>,
205 BoundActivityTracker: FromRequestParts<S>,
206 Encrypter: FromRef<S>,
207 reqwest::Client: FromRef<S>,
208 SiteConfig: FromRef<S>,
209 Templates: FromRef<S>,
210 Arc<dyn HomeserverConnection>: FromRef<S>,
211 BoxClock: FromRequestParts<S>,
212 BoxRng: FromRequestParts<S>,
213 Policy: FromRequestParts<S>,
214{
215 Router::new()
217 .route(
218 mas_router::OAuth2Keys::route(),
219 get(self::oauth2::keys::get),
220 )
221 .route(
222 mas_router::OidcUserinfo::route(),
223 get(self::oauth2::userinfo::get).post(self::oauth2::userinfo::get),
224 )
225 .route(
226 mas_router::OAuth2Introspection::route(),
227 post(self::oauth2::introspection::post),
228 )
229 .route(
230 mas_router::OAuth2Revocation::route(),
231 post(self::oauth2::revoke::post),
232 )
233 .route(
234 mas_router::OAuth2TokenEndpoint::route(),
235 post(self::oauth2::token::post),
236 )
237 .route(
238 mas_router::OAuth2RegistrationEndpoint::route(),
239 post(self::oauth2::registration::post),
240 )
241 .route(
242 mas_router::OAuth2DeviceAuthorizationEndpoint::route(),
243 post(self::oauth2::device::authorize::post),
244 )
245 .layer(
246 CorsLayer::new()
247 .allow_origin(Any)
248 .allow_methods(Any)
249 .allow_otel_headers([
250 AUTHORIZATION,
251 ACCEPT,
252 ACCEPT_LANGUAGE,
253 CONTENT_LANGUAGE,
254 CONTENT_TYPE,
255 HeaderName::from_static("x-requested-with"),
257 ])
258 .max_age(Duration::from_secs(60 * 60)),
259 )
260}
261
262#[allow(clippy::trait_duplication_in_bounds)]
263pub fn compat_router<S>(templates: Templates) -> Router<S>
264where
265 S: Clone + Send + Sync + 'static,
266 UrlBuilder: FromRef<S>,
267 SiteConfig: FromRef<S>,
268 Arc<dyn HomeserverConnection>: FromRef<S>,
269 PasswordManager: FromRef<S>,
270 Limiter: FromRef<S>,
271 BoxRepositoryFactory: FromRef<S>,
272 BoundActivityTracker: FromRequestParts<S>,
273 RequesterFingerprint: FromRequestParts<S>,
274 BoxRepository: FromRequestParts<S>,
275 BoxClock: FromRequestParts<S>,
276 BoxRng: FromRequestParts<S>,
277 Policy: FromRequestParts<S>,
278{
279 let human_router = Router::new()
281 .route(
282 mas_router::CompatLoginSsoRedirect::route(),
283 get(self::compat::login_sso_redirect::get),
284 )
285 .route(
286 mas_router::CompatLoginSsoRedirectIdp::route(),
287 get(self::compat::login_sso_redirect::get),
288 )
289 .route(
290 mas_router::CompatLoginSsoRedirectSlash::route(),
291 get(self::compat::login_sso_redirect::get),
292 )
293 .layer(AndThenLayer::new(
294 async move |response: axum::response::Response| {
295 Ok::<_, Infallible>(recover_error(&templates, response))
296 },
297 ));
298
299 let api_router = Router::new()
301 .route(
302 mas_router::CompatLogin::route(),
303 get(self::compat::login::get).post(self::compat::login::post),
304 )
305 .route(
306 mas_router::CompatLogout::route(),
307 post(self::compat::logout::post),
308 )
309 .route(
310 mas_router::CompatLogoutAll::route(),
311 post(self::compat::logout_all::post),
312 )
313 .route(
314 mas_router::CompatRefresh::route(),
315 post(self::compat::refresh::post),
316 )
317 .layer(
318 CorsLayer::new()
319 .allow_origin(Any)
320 .allow_methods(Any)
321 .allow_otel_headers([
322 AUTHORIZATION,
323 ACCEPT,
324 ACCEPT_LANGUAGE,
325 CONTENT_LANGUAGE,
326 CONTENT_TYPE,
327 HeaderName::from_static("x-requested-with"),
328 ])
329 .max_age(Duration::from_secs(60 * 60)),
330 );
331
332 Router::new().merge(human_router).merge(api_router)
333}
334
335pub fn human_router<S>(templates: Templates) -> Router<S>
336where
337 S: Clone + Send + Sync + 'static,
338 UrlBuilder: FromRef<S>,
339 PreferredLanguage: FromRequestParts<S>,
340 BoxRepository: FromRequestParts<S>,
341 CookieJar: FromRequestParts<S>,
342 BoundActivityTracker: FromRequestParts<S>,
343 RequesterFingerprint: FromRequestParts<S>,
344 Encrypter: FromRef<S>,
345 Templates: FromRef<S>,
346 Keystore: FromRef<S>,
347 PasswordManager: FromRef<S>,
348 MetadataCache: FromRef<S>,
349 SiteConfig: FromRef<S>,
350 Limiter: FromRef<S>,
351 reqwest::Client: FromRef<S>,
352 Arc<dyn HomeserverConnection>: FromRef<S>,
353 BoxClock: FromRequestParts<S>,
354 BoxRng: FromRequestParts<S>,
355 Policy: FromRequestParts<S>,
356{
357 Router::new()
358 .route(
360 "/account",
361 get(
362 async |State(url_builder): State<UrlBuilder>, RawQuery(query): RawQuery| {
363 let prefix = url_builder.prefix().unwrap_or_default();
364 let route = mas_router::Account::route();
365 let destination = if let Some(query) = query {
366 format!("{prefix}{route}?{query}")
367 } else {
368 format!("{prefix}{route}")
369 };
370
371 axum::response::Redirect::to(&destination)
372 },
373 ),
374 )
375 .route(mas_router::Account::route(), get(self::views::app::get))
376 .route(
377 mas_router::AccountWildcard::route(),
378 get(self::views::app::get),
379 )
380 .route(
381 mas_router::AccountRecoveryFinish::route(),
382 get(self::views::app::get_anonymous),
383 )
384 .route(
385 mas_router::ChangePasswordDiscovery::route(),
386 get(async |State(url_builder): State<UrlBuilder>| {
387 url_builder.redirect(&mas_router::AccountPasswordChange)
388 }),
389 )
390 .route(mas_router::Index::route(), get(self::views::index::get))
391 .route(
392 mas_router::Login::route(),
393 get(self::views::login::get).post(self::views::login::post),
394 )
395 .route(mas_router::Logout::route(), post(self::views::logout::post))
396 .route(
397 mas_router::Register::route(),
398 get(self::views::register::get),
399 )
400 .route(
401 mas_router::PasswordRegister::route(),
402 get(self::views::register::password::get).post(self::views::register::password::post),
403 )
404 .route(
405 mas_router::RegisterVerifyEmail::route(),
406 get(self::views::register::steps::verify_email::get)
407 .post(self::views::register::steps::verify_email::post),
408 )
409 .route(
410 mas_router::RegisterToken::route(),
411 get(self::views::register::steps::registration_token::get)
412 .post(self::views::register::steps::registration_token::post),
413 )
414 .route(
415 mas_router::RegisterDisplayName::route(),
416 get(self::views::register::steps::display_name::get)
417 .post(self::views::register::steps::display_name::post),
418 )
419 .route(
420 mas_router::RegisterFinish::route(),
421 get(self::views::register::steps::finish::get),
422 )
423 .route(
424 mas_router::AccountRecoveryStart::route(),
425 get(self::views::recovery::start::get).post(self::views::recovery::start::post),
426 )
427 .route(
428 mas_router::AccountRecoveryProgress::route(),
429 get(self::views::recovery::progress::get).post(self::views::recovery::progress::post),
430 )
431 .route(
432 mas_router::OAuth2AuthorizationEndpoint::route(),
433 get(self::oauth2::authorization::get),
434 )
435 .route(
436 mas_router::Consent::route(),
437 get(self::oauth2::authorization::consent::get)
438 .post(self::oauth2::authorization::consent::post),
439 )
440 .route(
441 mas_router::CompatLoginSsoComplete::route(),
442 get(self::compat::login_sso_complete::get).post(self::compat::login_sso_complete::post),
443 )
444 .route(
445 mas_router::UpstreamOAuth2Authorize::route(),
446 get(self::upstream_oauth2::authorize::get),
447 )
448 .route(
449 mas_router::UpstreamOAuth2Callback::route(),
450 get(self::upstream_oauth2::callback::handler)
451 .post(self::upstream_oauth2::callback::handler),
452 )
453 .route(
454 mas_router::UpstreamOAuth2Link::route(),
455 get(self::upstream_oauth2::link::get).post(self::upstream_oauth2::link::post),
456 )
457 .route(
458 mas_router::UpstreamOAuth2BackchannelLogout::route(),
459 post(self::upstream_oauth2::backchannel_logout::post),
460 )
461 .route(
462 mas_router::DeviceCodeLink::route(),
463 get(self::oauth2::device::link::get),
464 )
465 .route(
466 mas_router::DeviceCodeConsent::route(),
467 get(self::oauth2::device::consent::get).post(self::oauth2::device::consent::post),
468 )
469 .layer(AndThenLayer::new(
470 async move |response: axum::response::Response| {
471 Ok::<_, Infallible>(recover_error(&templates, response))
472 },
473 ))
474}
475
476fn recover_error(
477 templates: &Templates,
478 response: axum::response::Response,
479) -> axum::response::Response {
480 let ext = response.extensions().get::<ErrorContext>();
482 if let Some(ctx) = ext
483 && let Ok(res) = templates.render_error(ctx)
484 {
485 let (mut parts, _original_body) = response.into_parts();
486 parts.headers.remove(CONTENT_TYPE);
487 parts.headers.remove(CONTENT_LENGTH);
488 return (parts, Html(res)).into_response();
489 }
490
491 response
492}
493
494pub async fn fallback(
500 State(templates): State<Templates>,
501 OriginalUri(uri): OriginalUri,
502 method: Method,
503 version: Version,
504 PreferredLanguage(locale): PreferredLanguage,
505) -> Result<impl IntoResponse, InternalError> {
506 let ctx = NotFoundContext::new(&method, version, &uri).with_language(locale);
507 let res = templates.render_not_found(&ctx)?;
510
511 Ok((StatusCode::NOT_FOUND, Html(res)))
512}