1#![deny(clippy::future_not_send)]
8#![allow(
9 clippy::unused_async,
11 clippy::too_many_arguments,
13 clippy::let_with_type_underscore,
16)]
17
18use std::{
19 convert::Infallible,
20 sync::{Arc, LazyLock},
21 time::Duration,
22};
23
24use axum::{
25 Extension, Router,
26 extract::{FromRef, FromRequestParts, OriginalUri, RawQuery, State},
27 http::Method,
28 response::{Html, IntoResponse},
29 routing::{get, post},
30};
31use headers::HeaderName;
32use hyper::{
33 StatusCode, Version,
34 header::{
35 ACCEPT, ACCEPT_LANGUAGE, AUTHORIZATION, CONTENT_LANGUAGE, CONTENT_LENGTH, CONTENT_TYPE,
36 },
37};
38use mas_axum_utils::{InternalError, cookies::CookieJar};
39use mas_data_model::SiteConfig;
40use mas_http::CorsLayerExt;
41use mas_keystore::{Encrypter, Keystore};
42use mas_matrix::HomeserverConnection;
43use mas_policy::Policy;
44use mas_router::{Route, UrlBuilder};
45use mas_storage::{BoxClock, BoxRepository, BoxRepositoryFactory, BoxRng};
46use mas_templates::{ErrorContext, NotFoundContext, TemplateContext, Templates};
47use opentelemetry::metrics::Meter;
48use sqlx::PgPool;
49use tower::util::AndThenLayer;
50use tower_http::cors::{Any, CorsLayer};
51
52use self::{graphql::ExtraRouterParameters, passwords::PasswordManager};
53
54mod admin;
55mod compat;
56mod graphql;
57mod health;
58mod oauth2;
59pub mod passwords;
60pub mod upstream_oauth2;
61mod views;
62
63mod activity_tracker;
64mod captcha;
65mod preferred_language;
66mod rate_limit;
67mod session;
68#[cfg(test)]
69mod test_utils;
70
71static METER: LazyLock<Meter> = LazyLock::new(|| {
72 let scope = opentelemetry::InstrumentationScope::builder(env!("CARGO_PKG_NAME"))
73 .with_version(env!("CARGO_PKG_VERSION"))
74 .with_schema_url(opentelemetry_semantic_conventions::SCHEMA_URL)
75 .build();
76
77 opentelemetry::global::meter_with_scope(scope)
78});
79
/// Implements `From<$error>` for a route error type by boxing the source
/// error into that type's `Internal` variant.
///
/// Two forms are accepted:
///
/// - `impl_from_error_for_route!(Target: SourceError)` implements
///   `From<SourceError> for Target`;
/// - `impl_from_error_for_route!(SourceError)` is a shorthand which targets
///   `self::RouteError`, i.e. the `RouteError` type of the module in which
///   the macro is invoked.
///
/// The target type must have an `Internal` variant (or associated
/// constructor) accepting `Box<SourceError>`, possibly through an unsizing
/// coercion to a boxed trait object.
#[macro_export]
macro_rules! impl_from_error_for_route {
    ($route_error:ty : $error:ty) => {
        impl From<$error> for $route_error {
            fn from(e: $error) -> Self {
                Self::Internal(Box::new(e))
            }
        }
    };
    ($error:ty) => {
        // Recurse through `$crate::` so the shorthand keeps working when the
        // macro is invoked by path from a place where its bare name is not in
        // scope (macro names are resolved at the call site).
        $crate::impl_from_error_for_route!(self::RouteError: $error);
    };
}
95
96pub use mas_axum_utils::{ErrorWrapper, cookies::CookieManager};
97
98pub use self::{
99 activity_tracker::{ActivityTracker, Bound as BoundActivityTracker},
100 admin::router as admin_api_router,
101 graphql::{
102 Schema as GraphQLSchema, schema as graphql_schema, schema_builder as graphql_schema_builder,
103 },
104 preferred_language::PreferredLanguage,
105 rate_limit::{Limiter, RequesterFingerprint},
106 upstream_oauth2::cache::MetadataCache,
107};
108
109pub fn healthcheck_router<S>() -> Router<S>
110where
111 S: Clone + Send + Sync + 'static,
112 PgPool: FromRef<S>,
113{
114 Router::new().route(mas_router::Healthcheck::route(), get(self::health::get))
115}
116
117pub fn graphql_router<S>(playground: bool, undocumented_oauth2_access: bool) -> Router<S>
118where
119 S: Clone + Send + Sync + 'static,
120 graphql::Schema: FromRef<S>,
121 BoundActivityTracker: FromRequestParts<S>,
122 BoxRepository: FromRequestParts<S>,
123 BoxClock: FromRequestParts<S>,
124 Encrypter: FromRef<S>,
125 CookieJar: FromRequestParts<S>,
126 Limiter: FromRef<S>,
127 RequesterFingerprint: FromRequestParts<S>,
128{
129 let mut router = Router::new()
130 .route(
131 mas_router::GraphQL::route(),
132 get(self::graphql::get).post(self::graphql::post),
133 )
134 .layer(Extension(ExtraRouterParameters {
137 undocumented_oauth2_access,
138 }))
139 .layer(
140 CorsLayer::new()
141 .allow_origin(Any)
142 .allow_methods(Any)
143 .allow_otel_headers([
144 AUTHORIZATION,
145 ACCEPT,
146 ACCEPT_LANGUAGE,
147 CONTENT_LANGUAGE,
148 CONTENT_TYPE,
149 ]),
150 );
151
152 if playground {
153 router = router.route(
154 mas_router::GraphQLPlayground::route(),
155 get(self::graphql::playground),
156 );
157 }
158
159 router
160}
161
162pub fn discovery_router<S>() -> Router<S>
163where
164 S: Clone + Send + Sync + 'static,
165 Keystore: FromRef<S>,
166 SiteConfig: FromRef<S>,
167 UrlBuilder: FromRef<S>,
168 BoxClock: FromRequestParts<S>,
169 BoxRng: FromRequestParts<S>,
170{
171 Router::new()
172 .route(
173 mas_router::OidcConfiguration::route(),
174 get(self::oauth2::discovery::get),
175 )
176 .route(
177 mas_router::Webfinger::route(),
178 get(self::oauth2::webfinger::get),
179 )
180 .layer(
181 CorsLayer::new()
182 .allow_origin(Any)
183 .allow_methods(Any)
184 .allow_otel_headers([
185 AUTHORIZATION,
186 ACCEPT,
187 ACCEPT_LANGUAGE,
188 CONTENT_LANGUAGE,
189 CONTENT_TYPE,
190 ])
191 .max_age(Duration::from_secs(60 * 60)),
192 )
193}
194
195pub fn api_router<S>() -> Router<S>
196where
197 S: Clone + Send + Sync + 'static,
198 Keystore: FromRef<S>,
199 UrlBuilder: FromRef<S>,
200 BoxRepository: FromRequestParts<S>,
201 ActivityTracker: FromRequestParts<S>,
202 BoundActivityTracker: FromRequestParts<S>,
203 Encrypter: FromRef<S>,
204 reqwest::Client: FromRef<S>,
205 SiteConfig: FromRef<S>,
206 Templates: FromRef<S>,
207 Arc<dyn HomeserverConnection>: FromRef<S>,
208 BoxClock: FromRequestParts<S>,
209 BoxRng: FromRequestParts<S>,
210 Policy: FromRequestParts<S>,
211{
212 Router::new()
214 .route(
215 mas_router::OAuth2Keys::route(),
216 get(self::oauth2::keys::get),
217 )
218 .route(
219 mas_router::OidcUserinfo::route(),
220 get(self::oauth2::userinfo::get).post(self::oauth2::userinfo::get),
221 )
222 .route(
223 mas_router::OAuth2Introspection::route(),
224 post(self::oauth2::introspection::post),
225 )
226 .route(
227 mas_router::OAuth2Revocation::route(),
228 post(self::oauth2::revoke::post),
229 )
230 .route(
231 mas_router::OAuth2TokenEndpoint::route(),
232 post(self::oauth2::token::post),
233 )
234 .route(
235 mas_router::OAuth2RegistrationEndpoint::route(),
236 post(self::oauth2::registration::post),
237 )
238 .route(
239 mas_router::OAuth2DeviceAuthorizationEndpoint::route(),
240 post(self::oauth2::device::authorize::post),
241 )
242 .layer(
243 CorsLayer::new()
244 .allow_origin(Any)
245 .allow_methods(Any)
246 .allow_otel_headers([
247 AUTHORIZATION,
248 ACCEPT,
249 ACCEPT_LANGUAGE,
250 CONTENT_LANGUAGE,
251 CONTENT_TYPE,
252 HeaderName::from_static("x-requested-with"),
254 ])
255 .max_age(Duration::from_secs(60 * 60)),
256 )
257}
258
259#[allow(clippy::trait_duplication_in_bounds)]
260pub fn compat_router<S>(templates: Templates) -> Router<S>
261where
262 S: Clone + Send + Sync + 'static,
263 UrlBuilder: FromRef<S>,
264 SiteConfig: FromRef<S>,
265 Arc<dyn HomeserverConnection>: FromRef<S>,
266 PasswordManager: FromRef<S>,
267 Limiter: FromRef<S>,
268 BoxRepositoryFactory: FromRef<S>,
269 BoundActivityTracker: FromRequestParts<S>,
270 RequesterFingerprint: FromRequestParts<S>,
271 BoxRepository: FromRequestParts<S>,
272 BoxClock: FromRequestParts<S>,
273 BoxRng: FromRequestParts<S>,
274{
275 let human_router = Router::new()
277 .route(
278 mas_router::CompatLoginSsoRedirect::route(),
279 get(self::compat::login_sso_redirect::get),
280 )
281 .route(
282 mas_router::CompatLoginSsoRedirectIdp::route(),
283 get(self::compat::login_sso_redirect::get),
284 )
285 .route(
286 mas_router::CompatLoginSsoRedirectSlash::route(),
287 get(self::compat::login_sso_redirect::get),
288 )
289 .layer(AndThenLayer::new(
290 async move |response: axum::response::Response| {
291 Ok::<_, Infallible>(recover_error(&templates, response))
292 },
293 ));
294
295 let api_router = Router::new()
297 .route(
298 mas_router::CompatLogin::route(),
299 get(self::compat::login::get).post(self::compat::login::post),
300 )
301 .route(
302 mas_router::CompatLogout::route(),
303 post(self::compat::logout::post),
304 )
305 .route(
306 mas_router::CompatLogoutAll::route(),
307 post(self::compat::logout_all::post),
308 )
309 .route(
310 mas_router::CompatRefresh::route(),
311 post(self::compat::refresh::post),
312 )
313 .layer(
314 CorsLayer::new()
315 .allow_origin(Any)
316 .allow_methods(Any)
317 .allow_otel_headers([
318 AUTHORIZATION,
319 ACCEPT,
320 ACCEPT_LANGUAGE,
321 CONTENT_LANGUAGE,
322 CONTENT_TYPE,
323 HeaderName::from_static("x-requested-with"),
324 ])
325 .max_age(Duration::from_secs(60 * 60)),
326 );
327
328 Router::new().merge(human_router).merge(api_router)
329}
330
331pub fn human_router<S>(templates: Templates) -> Router<S>
332where
333 S: Clone + Send + Sync + 'static,
334 UrlBuilder: FromRef<S>,
335 PreferredLanguage: FromRequestParts<S>,
336 BoxRepository: FromRequestParts<S>,
337 CookieJar: FromRequestParts<S>,
338 BoundActivityTracker: FromRequestParts<S>,
339 RequesterFingerprint: FromRequestParts<S>,
340 Encrypter: FromRef<S>,
341 Templates: FromRef<S>,
342 Keystore: FromRef<S>,
343 PasswordManager: FromRef<S>,
344 MetadataCache: FromRef<S>,
345 SiteConfig: FromRef<S>,
346 Limiter: FromRef<S>,
347 reqwest::Client: FromRef<S>,
348 Arc<dyn HomeserverConnection>: FromRef<S>,
349 BoxClock: FromRequestParts<S>,
350 BoxRng: FromRequestParts<S>,
351 Policy: FromRequestParts<S>,
352{
353 Router::new()
354 .route(
356 "/account",
357 get(
358 async |State(url_builder): State<UrlBuilder>, RawQuery(query): RawQuery| {
359 let prefix = url_builder.prefix().unwrap_or_default();
360 let route = mas_router::Account::route();
361 let destination = if let Some(query) = query {
362 format!("{prefix}{route}?{query}")
363 } else {
364 format!("{prefix}{route}")
365 };
366
367 axum::response::Redirect::to(&destination)
368 },
369 ),
370 )
371 .route(mas_router::Account::route(), get(self::views::app::get))
372 .route(
373 mas_router::AccountWildcard::route(),
374 get(self::views::app::get),
375 )
376 .route(
377 mas_router::AccountRecoveryFinish::route(),
378 get(self::views::app::get_anonymous),
379 )
380 .route(
381 mas_router::ChangePasswordDiscovery::route(),
382 get(async |State(url_builder): State<UrlBuilder>| {
383 url_builder.redirect(&mas_router::AccountPasswordChange)
384 }),
385 )
386 .route(mas_router::Index::route(), get(self::views::index::get))
387 .route(
388 mas_router::Login::route(),
389 get(self::views::login::get).post(self::views::login::post),
390 )
391 .route(mas_router::Logout::route(), post(self::views::logout::post))
392 .route(
393 mas_router::Register::route(),
394 get(self::views::register::get),
395 )
396 .route(
397 mas_router::PasswordRegister::route(),
398 get(self::views::register::password::get).post(self::views::register::password::post),
399 )
400 .route(
401 mas_router::RegisterVerifyEmail::route(),
402 get(self::views::register::steps::verify_email::get)
403 .post(self::views::register::steps::verify_email::post),
404 )
405 .route(
406 mas_router::RegisterToken::route(),
407 get(self::views::register::steps::registration_token::get)
408 .post(self::views::register::steps::registration_token::post),
409 )
410 .route(
411 mas_router::RegisterDisplayName::route(),
412 get(self::views::register::steps::display_name::get)
413 .post(self::views::register::steps::display_name::post),
414 )
415 .route(
416 mas_router::RegisterFinish::route(),
417 get(self::views::register::steps::finish::get),
418 )
419 .route(
420 mas_router::AccountRecoveryStart::route(),
421 get(self::views::recovery::start::get).post(self::views::recovery::start::post),
422 )
423 .route(
424 mas_router::AccountRecoveryProgress::route(),
425 get(self::views::recovery::progress::get).post(self::views::recovery::progress::post),
426 )
427 .route(
428 mas_router::OAuth2AuthorizationEndpoint::route(),
429 get(self::oauth2::authorization::get),
430 )
431 .route(
432 mas_router::Consent::route(),
433 get(self::oauth2::authorization::consent::get)
434 .post(self::oauth2::authorization::consent::post),
435 )
436 .route(
437 mas_router::CompatLoginSsoComplete::route(),
438 get(self::compat::login_sso_complete::get).post(self::compat::login_sso_complete::post),
439 )
440 .route(
441 mas_router::UpstreamOAuth2Authorize::route(),
442 get(self::upstream_oauth2::authorize::get),
443 )
444 .route(
445 mas_router::UpstreamOAuth2Callback::route(),
446 get(self::upstream_oauth2::callback::handler)
447 .post(self::upstream_oauth2::callback::handler),
448 )
449 .route(
450 mas_router::UpstreamOAuth2Link::route(),
451 get(self::upstream_oauth2::link::get).post(self::upstream_oauth2::link::post),
452 )
453 .route(
454 mas_router::UpstreamOAuth2BackchannelLogout::route(),
455 post(self::upstream_oauth2::backchannel_logout::post),
456 )
457 .route(
458 mas_router::DeviceCodeLink::route(),
459 get(self::oauth2::device::link::get),
460 )
461 .route(
462 mas_router::DeviceCodeConsent::route(),
463 get(self::oauth2::device::consent::get).post(self::oauth2::device::consent::post),
464 )
465 .layer(AndThenLayer::new(
466 async move |response: axum::response::Response| {
467 Ok::<_, Infallible>(recover_error(&templates, response))
468 },
469 ))
470}
471
472fn recover_error(
473 templates: &Templates,
474 response: axum::response::Response,
475) -> axum::response::Response {
476 let ext = response.extensions().get::<ErrorContext>();
478 if let Some(ctx) = ext {
479 if let Ok(res) = templates.render_error(ctx) {
480 let (mut parts, _original_body) = response.into_parts();
481 parts.headers.remove(CONTENT_TYPE);
482 parts.headers.remove(CONTENT_LENGTH);
483 return (parts, Html(res)).into_response();
484 }
485 }
486
487 response
488}
489
490pub async fn fallback(
496 State(templates): State<Templates>,
497 OriginalUri(uri): OriginalUri,
498 method: Method,
499 version: Version,
500 PreferredLanguage(locale): PreferredLanguage,
501) -> Result<impl IntoResponse, InternalError> {
502 let ctx = NotFoundContext::new(&method, version, &uri).with_language(locale);
503 let res = templates.render_not_found(&ctx)?;
506
507 Ok((StatusCode::NOT_FOUND, Html(res)))
508}