1#![allow(clippy::module_name_repetitions)]
8
9use std::{net::IpAddr, ops::Deref, sync::Arc};
10
11use async_graphql::{
12 EmptySubscription, InputObject,
13 extensions::Tracing,
14 http::{GraphQLPlaygroundConfig, MultipartOptions, playground_source},
15};
16use axum::{
17 Extension, Json,
18 body::Body,
19 extract::{RawQuery, State as AxumState},
20 http::StatusCode,
21 response::{Html, IntoResponse, Response},
22};
23use axum_extra::typed_header::TypedHeader;
24use chrono::{DateTime, Utc};
25use futures_util::TryStreamExt;
26use headers::{Authorization, ContentType, HeaderValue, authorization::Bearer};
27use hyper::header::CACHE_CONTROL;
28use mas_axum_utils::{
29 InternalError, SessionInfo, SessionInfoExt, cookies::CookieJar, sentry::SentryEventID,
30};
31use mas_data_model::{BrowserSession, Session, SiteConfig, User};
32use mas_matrix::HomeserverConnection;
33use mas_policy::{InstantiateError, Policy, PolicyFactory};
34use mas_router::UrlBuilder;
35use mas_storage::{
36 BoxClock, BoxRepository, BoxRepositoryFactory, BoxRng, Clock, RepositoryError, SystemClock,
37};
38use opentelemetry_semantic_conventions::trace::{GRAPHQL_DOCUMENT, GRAPHQL_OPERATION_NAME};
39use rand::{SeedableRng, thread_rng};
40use rand_chacha::ChaChaRng;
41use state::has_session_ended;
42use tracing::{Instrument, info_span};
43use ulid::Ulid;
44
45mod model;
46mod mutations;
47mod query;
48mod state;
49
50pub use self::state::{BoxState, State};
51use self::{
52 model::{CreationEvent, Node},
53 mutations::Mutation,
54 query::Query,
55};
56use crate::{
57 BoundActivityTracker, Limiter, RequesterFingerprint, impl_from_error_for_route,
58 passwords::PasswordManager,
59};
60
61#[cfg(test)]
62mod tests;
63
/// Extra, deployment-specific parameters for the GraphQL router.
///
/// Handlers receive this through an axum `Extension`.
#[derive(Clone, Debug)]
pub struct ExtraRouterParameters {
    /// Whether requests may authenticate with an OAuth 2.0 access token
    /// (bearer `Authorization` header) instead of a browser session cookie.
    pub undocumented_oauth2_access: bool,
}
70
71struct GraphQLState {
72 repository_factory: BoxRepositoryFactory,
73 homeserver_connection: Arc<dyn HomeserverConnection>,
74 policy_factory: Arc<PolicyFactory>,
75 site_config: SiteConfig,
76 password_manager: PasswordManager,
77 url_builder: UrlBuilder,
78 limiter: Limiter,
79}
80
81#[async_trait::async_trait]
82impl state::State for GraphQLState {
83 async fn repository(&self) -> Result<BoxRepository, RepositoryError> {
84 self.repository_factory.create().await
85 }
86
87 async fn policy(&self) -> Result<Policy, InstantiateError> {
88 self.policy_factory.instantiate().await
89 }
90
91 fn password_manager(&self) -> PasswordManager {
92 self.password_manager.clone()
93 }
94
95 fn site_config(&self) -> &SiteConfig {
96 &self.site_config
97 }
98
99 fn homeserver_connection(&self) -> &dyn HomeserverConnection {
100 self.homeserver_connection.as_ref()
101 }
102
103 fn url_builder(&self) -> &UrlBuilder {
104 &self.url_builder
105 }
106
107 fn limiter(&self) -> &Limiter {
108 &self.limiter
109 }
110
111 fn clock(&self) -> BoxClock {
112 let clock = SystemClock::default();
113 Box::new(clock)
114 }
115
116 fn rng(&self) -> BoxRng {
117 #[allow(clippy::disallowed_methods)]
118 let rng = thread_rng();
119
120 let rng = ChaChaRng::from_rng(rng).expect("Failed to seed rng");
121 Box::new(rng)
122 }
123}
124
125#[must_use]
126pub fn schema(
127 repository_factory: BoxRepositoryFactory,
128 policy_factory: &Arc<PolicyFactory>,
129 homeserver_connection: impl HomeserverConnection + 'static,
130 site_config: SiteConfig,
131 password_manager: PasswordManager,
132 url_builder: UrlBuilder,
133 limiter: Limiter,
134) -> Schema {
135 let state = GraphQLState {
136 repository_factory,
137 policy_factory: Arc::clone(policy_factory),
138 homeserver_connection: Arc::new(homeserver_connection),
139 site_config,
140 password_manager,
141 url_builder,
142 limiter,
143 };
144 let state: BoxState = Box::new(state);
145
146 schema_builder().extension(Tracing).data(state).finish()
147}
148
149fn span_for_graphql_request(request: &async_graphql::Request) -> tracing::Span {
150 let span = info_span!(
151 "GraphQL operation",
152 "otel.name" = tracing::field::Empty,
153 "otel.kind" = "server",
154 { GRAPHQL_DOCUMENT } = request.query,
155 { GRAPHQL_OPERATION_NAME } = tracing::field::Empty,
156 );
157
158 if let Some(name) = &request.operation_name {
159 span.record("otel.name", name);
160 span.record(GRAPHQL_OPERATION_NAME, name);
161 }
162
163 span
164}
165
166#[derive(thiserror::Error, Debug)]
167pub enum RouteError {
168 #[error(transparent)]
169 Internal(Box<dyn std::error::Error + Send + Sync + 'static>),
170
171 #[error("Loading of some database objects failed")]
172 LoadFailed,
173
174 #[error("Invalid access token")]
175 InvalidToken,
176
177 #[error("Missing scope")]
178 MissingScope,
179
180 #[error(transparent)]
181 ParseRequest(#[from] async_graphql::ParseRequestError),
182}
183
184impl_from_error_for_route!(mas_storage::RepositoryError);
185
186impl IntoResponse for RouteError {
187 fn into_response(self) -> Response {
188 let event_id = sentry::capture_error(&self);
189
190 let response = match self {
191 e @ (Self::Internal(_) | Self::LoadFailed) => {
192 let error = async_graphql::Error::new_with_source(e);
193 (
194 StatusCode::INTERNAL_SERVER_ERROR,
195 Json(serde_json::json!({"errors": [error]})),
196 )
197 .into_response()
198 }
199
200 Self::InvalidToken => {
201 let error = async_graphql::Error::new("Invalid token");
202 (
203 StatusCode::UNAUTHORIZED,
204 Json(serde_json::json!({"errors": [error]})),
205 )
206 .into_response()
207 }
208
209 Self::MissingScope => {
210 let error = async_graphql::Error::new("Missing urn:mas:graphql:* scope");
211 (
212 StatusCode::UNAUTHORIZED,
213 Json(serde_json::json!({"errors": [error]})),
214 )
215 .into_response()
216 }
217
218 Self::ParseRequest(e) => {
219 let error = async_graphql::Error::new_with_source(e);
220 (
221 StatusCode::BAD_REQUEST,
222 Json(serde_json::json!({"errors": [error]})),
223 )
224 .into_response()
225 }
226 };
227
228 (SentryEventID::from(event_id), response).into_response()
229 }
230}
231
232async fn get_requester(
233 undocumented_oauth2_access: bool,
234 clock: &impl Clock,
235 activity_tracker: &BoundActivityTracker,
236 mut repo: BoxRepository,
237 session_info: &SessionInfo,
238 user_agent: Option<String>,
239 token: Option<&str>,
240) -> Result<Requester, RouteError> {
241 let entity = if let Some(token) = token {
242 if !undocumented_oauth2_access {
244 return Err(RouteError::InvalidToken);
245 }
246
247 let token = repo
248 .oauth2_access_token()
249 .find_by_token(token)
250 .await?
251 .ok_or(RouteError::InvalidToken)?;
252
253 let session = repo
254 .oauth2_session()
255 .lookup(token.session_id)
256 .await?
257 .ok_or(RouteError::LoadFailed)?;
258
259 activity_tracker
260 .record_oauth2_session(clock, &session)
261 .await;
262
263 let user = if let Some(user_id) = session.user_id {
265 let user = repo
266 .user()
267 .lookup(user_id)
268 .await?
269 .ok_or(RouteError::LoadFailed)?;
270 Some(user)
271 } else {
272 None
273 };
274
275 let user_valid = user.as_ref().is_none_or(User::is_valid);
277
278 if !token.is_valid(clock.now()) || !session.is_valid() || !user_valid {
279 return Err(RouteError::InvalidToken);
280 }
281
282 if !session.scope.contains("urn:mas:graphql:*") {
283 return Err(RouteError::MissingScope);
284 }
285
286 RequestingEntity::OAuth2Session(Box::new((session, user)))
287 } else {
288 let maybe_session = session_info.load_active_session(&mut repo).await?;
289
290 if let Some(session) = maybe_session.as_ref() {
291 activity_tracker
292 .record_browser_session(clock, session)
293 .await;
294 }
295
296 RequestingEntity::from(maybe_session)
297 };
298
299 let requester = Requester {
300 entity,
301 ip_address: activity_tracker.ip(),
302 user_agent,
303 };
304
305 repo.cancel().await?;
306 Ok(requester)
307}
308
309pub async fn post(
310 AxumState(schema): AxumState<Schema>,
311 Extension(ExtraRouterParameters {
312 undocumented_oauth2_access,
313 }): Extension<ExtraRouterParameters>,
314 clock: BoxClock,
315 repo: BoxRepository,
316 activity_tracker: BoundActivityTracker,
317 cookie_jar: CookieJar,
318 content_type: Option<TypedHeader<ContentType>>,
319 authorization: Option<TypedHeader<Authorization<Bearer>>>,
320 user_agent: Option<TypedHeader<headers::UserAgent>>,
321 body: Body,
322) -> Result<impl IntoResponse, RouteError> {
323 let body = body.into_data_stream();
324 let token = authorization
325 .as_ref()
326 .map(|TypedHeader(Authorization(bearer))| bearer.token());
327 let user_agent = user_agent.map(|TypedHeader(h)| h.to_string());
328 let (session_info, mut cookie_jar) = cookie_jar.session_info();
329 let requester = get_requester(
330 undocumented_oauth2_access,
331 &clock,
332 &activity_tracker,
333 repo,
334 &session_info,
335 user_agent,
336 token,
337 )
338 .await?;
339
340 let content_type = content_type.map(|TypedHeader(h)| h.to_string());
341
342 let request = async_graphql::http::receive_body(
343 content_type,
344 body.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))
345 .into_async_read(),
346 MultipartOptions::default(),
347 )
348 .await?
349 .data(requester); let span = span_for_graphql_request(&request);
352 let mut response = schema.execute(request).instrument(span).await;
353
354 if has_session_ended(&mut response) {
355 let session_info = session_info.mark_session_ended();
356 cookie_jar = cookie_jar.update_session_info(&session_info);
357 }
358
359 let cache_control = response
360 .cache_control
361 .value()
362 .and_then(|v| HeaderValue::from_str(&v).ok())
363 .map(|h| [(CACHE_CONTROL, h)]);
364
365 let headers = response.http_headers.clone();
366
367 Ok((headers, cache_control, cookie_jar, Json(response)))
368}
369
370pub async fn get(
371 AxumState(schema): AxumState<Schema>,
372 Extension(ExtraRouterParameters {
373 undocumented_oauth2_access,
374 }): Extension<ExtraRouterParameters>,
375 clock: BoxClock,
376 repo: BoxRepository,
377 activity_tracker: BoundActivityTracker,
378 cookie_jar: CookieJar,
379 authorization: Option<TypedHeader<Authorization<Bearer>>>,
380 user_agent: Option<TypedHeader<headers::UserAgent>>,
381 RawQuery(query): RawQuery,
382) -> Result<impl IntoResponse, InternalError> {
383 let token = authorization
384 .as_ref()
385 .map(|TypedHeader(Authorization(bearer))| bearer.token());
386 let user_agent = user_agent.map(|TypedHeader(h)| h.to_string());
387 let (session_info, mut cookie_jar) = cookie_jar.session_info();
388 let requester = get_requester(
389 undocumented_oauth2_access,
390 &clock,
391 &activity_tracker,
392 repo,
393 &session_info,
394 user_agent,
395 token,
396 )
397 .await?;
398
399 let request =
400 async_graphql::http::parse_query_string(&query.unwrap_or_default())?.data(requester);
401
402 let span = span_for_graphql_request(&request);
403 let mut response = schema.execute(request).instrument(span).await;
404
405 if has_session_ended(&mut response) {
406 let session_info = session_info.mark_session_ended();
407 cookie_jar = cookie_jar.update_session_info(&session_info);
408 }
409
410 let cache_control = response
411 .cache_control
412 .value()
413 .and_then(|v| HeaderValue::from_str(&v).ok())
414 .map(|h| [(CACHE_CONTROL, h)]);
415
416 let headers = response.http_headers.clone();
417
418 Ok((headers, cache_control, cookie_jar, Json(response)))
419}
420
421pub async fn playground() -> impl IntoResponse {
422 Html(playground_source(
423 GraphQLPlaygroundConfig::new("/graphql").with_setting("request.credentials", "include"),
424 ))
425}
426
427pub type Schema = async_graphql::Schema<Query, Mutation, EmptySubscription>;
428pub type SchemaBuilder = async_graphql::SchemaBuilder<Query, Mutation, EmptySubscription>;
429
430#[must_use]
431pub fn schema_builder() -> SchemaBuilder {
432 async_graphql::Schema::build(Query::new(), Mutation::new(), EmptySubscription)
433 .register_output_type::<Node>()
434 .register_output_type::<CreationEvent>()
435}
436
437pub struct Requester {
438 entity: RequestingEntity,
439 ip_address: Option<IpAddr>,
440 user_agent: Option<String>,
441}
442
443impl Requester {
444 pub fn fingerprint(&self) -> RequesterFingerprint {
445 if let Some(ip) = self.ip_address {
446 RequesterFingerprint::new(ip)
447 } else {
448 RequesterFingerprint::EMPTY
449 }
450 }
451
452 pub fn for_policy(&self) -> mas_policy::Requester {
453 mas_policy::Requester {
454 ip_address: self.ip_address,
455 user_agent: self.user_agent.clone(),
456 }
457 }
458}
459
460impl Deref for Requester {
461 type Target = RequestingEntity;
462
463 fn deref(&self) -> &Self::Target {
464 &self.entity
465 }
466}
467
468#[derive(Debug, Clone, Default, PartialEq, Eq)]
470pub enum RequestingEntity {
471 #[default]
473 Anonymous,
474
475 BrowserSession(Box<BrowserSession>),
477
478 OAuth2Session(Box<(Session, Option<User>)>),
480}
481
482trait OwnerId {
483 fn owner_id(&self) -> Option<Ulid>;
484}
485
486impl OwnerId for User {
487 fn owner_id(&self) -> Option<Ulid> {
488 Some(self.id)
489 }
490}
491
492impl OwnerId for BrowserSession {
493 fn owner_id(&self) -> Option<Ulid> {
494 Some(self.user.id)
495 }
496}
497
498impl OwnerId for mas_data_model::UserEmail {
499 fn owner_id(&self) -> Option<Ulid> {
500 Some(self.user_id)
501 }
502}
503
504impl OwnerId for Session {
505 fn owner_id(&self) -> Option<Ulid> {
506 self.user_id
507 }
508}
509
510impl OwnerId for mas_data_model::CompatSession {
511 fn owner_id(&self) -> Option<Ulid> {
512 Some(self.user_id)
513 }
514}
515
516impl OwnerId for mas_data_model::UpstreamOAuthLink {
517 fn owner_id(&self) -> Option<Ulid> {
518 self.user_id
519 }
520}
521
522pub struct UserId(Ulid);
524
525impl OwnerId for UserId {
526 fn owner_id(&self) -> Option<Ulid> {
527 Some(self.0)
528 }
529}
530
531impl RequestingEntity {
532 fn browser_session(&self) -> Option<&BrowserSession> {
533 match self {
534 Self::BrowserSession(session) => Some(session),
535 Self::OAuth2Session(_) | Self::Anonymous => None,
536 }
537 }
538
539 fn user(&self) -> Option<&User> {
540 match self {
541 Self::BrowserSession(session) => Some(&session.user),
542 Self::OAuth2Session(tuple) => tuple.1.as_ref(),
543 Self::Anonymous => None,
544 }
545 }
546
547 fn oauth2_session(&self) -> Option<&Session> {
548 match self {
549 Self::OAuth2Session(tuple) => Some(&tuple.0),
550 Self::BrowserSession(_) | Self::Anonymous => None,
551 }
552 }
553
554 fn is_owner_or_admin(&self, resource: &impl OwnerId) -> bool {
556 if self.is_admin() {
558 return true;
559 }
560
561 let Some(owner_id) = resource.owner_id() else {
563 return false;
564 };
565
566 let Some(user) = self.user() else {
567 return false;
568 };
569
570 user.id == owner_id
571 }
572
573 fn is_admin(&self) -> bool {
574 match self {
575 Self::OAuth2Session(tuple) => {
576 tuple.0.scope.contains("urn:mas:admin")
579 }
580 Self::BrowserSession(_) | Self::Anonymous => false,
581 }
582 }
583
584 fn is_unauthenticated(&self) -> bool {
585 matches!(self, Self::Anonymous)
586 }
587}
588
589impl From<BrowserSession> for RequestingEntity {
590 fn from(session: BrowserSession) -> Self {
591 Self::BrowserSession(Box::new(session))
592 }
593}
594
595impl<T> From<Option<T>> for RequestingEntity
596where
597 T: Into<RequestingEntity>,
598{
599 fn from(session: Option<T>) -> Self {
600 session.map(Into::into).unwrap_or_default()
601 }
602}
603
604#[derive(InputObject, Default, Clone, Copy)]
606pub struct DateFilter {
607 after: Option<DateTime<Utc>>,
609
610 before: Option<DateTime<Utc>>,
612}