mas_handlers/graphql/
mod.rs

1// Copyright 2024, 2025 New Vector Ltd.
2// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
7#![allow(clippy::module_name_repetitions)]
8
9use std::{net::IpAddr, ops::Deref, sync::Arc};
10
11use async_graphql::{
12    EmptySubscription, InputObject,
13    extensions::Tracing,
14    http::{GraphQLPlaygroundConfig, MultipartOptions, playground_source},
15};
16use axum::{
17    Extension, Json,
18    body::Body,
19    extract::{RawQuery, State as AxumState},
20    http::StatusCode,
21    response::{Html, IntoResponse, Response},
22};
23use axum_extra::typed_header::TypedHeader;
24use chrono::{DateTime, Utc};
25use futures_util::TryStreamExt;
26use headers::{Authorization, ContentType, HeaderValue, authorization::Bearer};
27use hyper::header::CACHE_CONTROL;
28use mas_axum_utils::{
29    InternalError, SessionInfo, SessionInfoExt, cookies::CookieJar, sentry::SentryEventID,
30};
31use mas_data_model::{BrowserSession, Session, SiteConfig, User};
32use mas_matrix::HomeserverConnection;
33use mas_policy::{InstantiateError, Policy, PolicyFactory};
34use mas_router::UrlBuilder;
35use mas_storage::{
36    BoxClock, BoxRepository, BoxRepositoryFactory, BoxRng, Clock, RepositoryError, SystemClock,
37};
38use opentelemetry_semantic_conventions::trace::{GRAPHQL_DOCUMENT, GRAPHQL_OPERATION_NAME};
39use rand::{SeedableRng, thread_rng};
40use rand_chacha::ChaChaRng;
41use state::has_session_ended;
42use tracing::{Instrument, info_span};
43use ulid::Ulid;
44
45mod model;
46mod mutations;
47mod query;
48mod state;
49
50pub use self::state::{BoxState, State};
51use self::{
52    model::{CreationEvent, Node},
53    mutations::Mutation,
54    query::Query,
55};
56use crate::{
57    BoundActivityTracker, Limiter, RequesterFingerprint, impl_from_error_for_route,
58    passwords::PasswordManager,
59};
60
61#[cfg(test)]
62mod tests;
63
/// Options that come from an individual listener's configuration.
///
/// They can differ from one listener to another, so rather than living in the
/// shared GraphQL state they are attached to every request as an extension.
#[derive(Clone, Debug)]
pub struct ExtraRouterParameters {
    /// Whether requests may authenticate with an OAuth 2.0 bearer access token.
    /// When this is `false`, any request that carries a bearer token is
    /// rejected as having an invalid token.
    pub undocumented_oauth2_access: bool,
}
70
71struct GraphQLState {
72    repository_factory: BoxRepositoryFactory,
73    homeserver_connection: Arc<dyn HomeserverConnection>,
74    policy_factory: Arc<PolicyFactory>,
75    site_config: SiteConfig,
76    password_manager: PasswordManager,
77    url_builder: UrlBuilder,
78    limiter: Limiter,
79}
80
81#[async_trait::async_trait]
82impl state::State for GraphQLState {
83    async fn repository(&self) -> Result<BoxRepository, RepositoryError> {
84        self.repository_factory.create().await
85    }
86
87    async fn policy(&self) -> Result<Policy, InstantiateError> {
88        self.policy_factory.instantiate().await
89    }
90
91    fn password_manager(&self) -> PasswordManager {
92        self.password_manager.clone()
93    }
94
95    fn site_config(&self) -> &SiteConfig {
96        &self.site_config
97    }
98
99    fn homeserver_connection(&self) -> &dyn HomeserverConnection {
100        self.homeserver_connection.as_ref()
101    }
102
103    fn url_builder(&self) -> &UrlBuilder {
104        &self.url_builder
105    }
106
107    fn limiter(&self) -> &Limiter {
108        &self.limiter
109    }
110
111    fn clock(&self) -> BoxClock {
112        let clock = SystemClock::default();
113        Box::new(clock)
114    }
115
116    fn rng(&self) -> BoxRng {
117        #[allow(clippy::disallowed_methods)]
118        let rng = thread_rng();
119
120        let rng = ChaChaRng::from_rng(rng).expect("Failed to seed rng");
121        Box::new(rng)
122    }
123}
124
125#[must_use]
126pub fn schema(
127    repository_factory: BoxRepositoryFactory,
128    policy_factory: &Arc<PolicyFactory>,
129    homeserver_connection: impl HomeserverConnection + 'static,
130    site_config: SiteConfig,
131    password_manager: PasswordManager,
132    url_builder: UrlBuilder,
133    limiter: Limiter,
134) -> Schema {
135    let state = GraphQLState {
136        repository_factory,
137        policy_factory: Arc::clone(policy_factory),
138        homeserver_connection: Arc::new(homeserver_connection),
139        site_config,
140        password_manager,
141        url_builder,
142        limiter,
143    };
144    let state: BoxState = Box::new(state);
145
146    schema_builder().extension(Tracing).data(state).finish()
147}
148
149fn span_for_graphql_request(request: &async_graphql::Request) -> tracing::Span {
150    let span = info_span!(
151        "GraphQL operation",
152        "otel.name" = tracing::field::Empty,
153        "otel.kind" = "server",
154        { GRAPHQL_DOCUMENT } = request.query,
155        { GRAPHQL_OPERATION_NAME } = tracing::field::Empty,
156    );
157
158    if let Some(name) = &request.operation_name {
159        span.record("otel.name", name);
160        span.record(GRAPHQL_OPERATION_NAME, name);
161    }
162
163    span
164}
165
166#[derive(thiserror::Error, Debug)]
167pub enum RouteError {
168    #[error(transparent)]
169    Internal(Box<dyn std::error::Error + Send + Sync + 'static>),
170
171    #[error("Loading of some database objects failed")]
172    LoadFailed,
173
174    #[error("Invalid access token")]
175    InvalidToken,
176
177    #[error("Missing scope")]
178    MissingScope,
179
180    #[error(transparent)]
181    ParseRequest(#[from] async_graphql::ParseRequestError),
182}
183
184impl_from_error_for_route!(mas_storage::RepositoryError);
185
186impl IntoResponse for RouteError {
187    fn into_response(self) -> Response {
188        let event_id = sentry::capture_error(&self);
189
190        let response = match self {
191            e @ (Self::Internal(_) | Self::LoadFailed) => {
192                let error = async_graphql::Error::new_with_source(e);
193                (
194                    StatusCode::INTERNAL_SERVER_ERROR,
195                    Json(serde_json::json!({"errors": [error]})),
196                )
197                    .into_response()
198            }
199
200            Self::InvalidToken => {
201                let error = async_graphql::Error::new("Invalid token");
202                (
203                    StatusCode::UNAUTHORIZED,
204                    Json(serde_json::json!({"errors": [error]})),
205                )
206                    .into_response()
207            }
208
209            Self::MissingScope => {
210                let error = async_graphql::Error::new("Missing urn:mas:graphql:* scope");
211                (
212                    StatusCode::UNAUTHORIZED,
213                    Json(serde_json::json!({"errors": [error]})),
214                )
215                    .into_response()
216            }
217
218            Self::ParseRequest(e) => {
219                let error = async_graphql::Error::new_with_source(e);
220                (
221                    StatusCode::BAD_REQUEST,
222                    Json(serde_json::json!({"errors": [error]})),
223                )
224                    .into_response()
225            }
226        };
227
228        (SentryEventID::from(event_id), response).into_response()
229    }
230}
231
232async fn get_requester(
233    undocumented_oauth2_access: bool,
234    clock: &impl Clock,
235    activity_tracker: &BoundActivityTracker,
236    mut repo: BoxRepository,
237    session_info: &SessionInfo,
238    user_agent: Option<String>,
239    token: Option<&str>,
240) -> Result<Requester, RouteError> {
241    let entity = if let Some(token) = token {
242        // If we haven't enabled undocumented_oauth2_access on the listener, we bail out
243        if !undocumented_oauth2_access {
244            return Err(RouteError::InvalidToken);
245        }
246
247        let token = repo
248            .oauth2_access_token()
249            .find_by_token(token)
250            .await?
251            .ok_or(RouteError::InvalidToken)?;
252
253        let session = repo
254            .oauth2_session()
255            .lookup(token.session_id)
256            .await?
257            .ok_or(RouteError::LoadFailed)?;
258
259        activity_tracker
260            .record_oauth2_session(clock, &session)
261            .await;
262
263        // Load the user if there is one
264        let user = if let Some(user_id) = session.user_id {
265            let user = repo
266                .user()
267                .lookup(user_id)
268                .await?
269                .ok_or(RouteError::LoadFailed)?;
270            Some(user)
271        } else {
272            None
273        };
274
275        // If there is a user for this session, check that it is not locked
276        let user_valid = user.as_ref().is_none_or(User::is_valid);
277
278        if !token.is_valid(clock.now()) || !session.is_valid() || !user_valid {
279            return Err(RouteError::InvalidToken);
280        }
281
282        if !session.scope.contains("urn:mas:graphql:*") {
283            return Err(RouteError::MissingScope);
284        }
285
286        RequestingEntity::OAuth2Session(Box::new((session, user)))
287    } else {
288        let maybe_session = session_info.load_active_session(&mut repo).await?;
289
290        if let Some(session) = maybe_session.as_ref() {
291            activity_tracker
292                .record_browser_session(clock, session)
293                .await;
294        }
295
296        RequestingEntity::from(maybe_session)
297    };
298
299    let requester = Requester {
300        entity,
301        ip_address: activity_tracker.ip(),
302        user_agent,
303    };
304
305    repo.cancel().await?;
306    Ok(requester)
307}
308
309pub async fn post(
310    AxumState(schema): AxumState<Schema>,
311    Extension(ExtraRouterParameters {
312        undocumented_oauth2_access,
313    }): Extension<ExtraRouterParameters>,
314    clock: BoxClock,
315    repo: BoxRepository,
316    activity_tracker: BoundActivityTracker,
317    cookie_jar: CookieJar,
318    content_type: Option<TypedHeader<ContentType>>,
319    authorization: Option<TypedHeader<Authorization<Bearer>>>,
320    user_agent: Option<TypedHeader<headers::UserAgent>>,
321    body: Body,
322) -> Result<impl IntoResponse, RouteError> {
323    let body = body.into_data_stream();
324    let token = authorization
325        .as_ref()
326        .map(|TypedHeader(Authorization(bearer))| bearer.token());
327    let user_agent = user_agent.map(|TypedHeader(h)| h.to_string());
328    let (session_info, mut cookie_jar) = cookie_jar.session_info();
329    let requester = get_requester(
330        undocumented_oauth2_access,
331        &clock,
332        &activity_tracker,
333        repo,
334        &session_info,
335        user_agent,
336        token,
337    )
338    .await?;
339
340    let content_type = content_type.map(|TypedHeader(h)| h.to_string());
341
342    let request = async_graphql::http::receive_body(
343        content_type,
344        body.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))
345            .into_async_read(),
346        MultipartOptions::default(),
347    )
348    .await?
349    .data(requester); // XXX: this should probably return another error response?
350
351    let span = span_for_graphql_request(&request);
352    let mut response = schema.execute(request).instrument(span).await;
353
354    if has_session_ended(&mut response) {
355        let session_info = session_info.mark_session_ended();
356        cookie_jar = cookie_jar.update_session_info(&session_info);
357    }
358
359    let cache_control = response
360        .cache_control
361        .value()
362        .and_then(|v| HeaderValue::from_str(&v).ok())
363        .map(|h| [(CACHE_CONTROL, h)]);
364
365    let headers = response.http_headers.clone();
366
367    Ok((headers, cache_control, cookie_jar, Json(response)))
368}
369
370pub async fn get(
371    AxumState(schema): AxumState<Schema>,
372    Extension(ExtraRouterParameters {
373        undocumented_oauth2_access,
374    }): Extension<ExtraRouterParameters>,
375    clock: BoxClock,
376    repo: BoxRepository,
377    activity_tracker: BoundActivityTracker,
378    cookie_jar: CookieJar,
379    authorization: Option<TypedHeader<Authorization<Bearer>>>,
380    user_agent: Option<TypedHeader<headers::UserAgent>>,
381    RawQuery(query): RawQuery,
382) -> Result<impl IntoResponse, InternalError> {
383    let token = authorization
384        .as_ref()
385        .map(|TypedHeader(Authorization(bearer))| bearer.token());
386    let user_agent = user_agent.map(|TypedHeader(h)| h.to_string());
387    let (session_info, mut cookie_jar) = cookie_jar.session_info();
388    let requester = get_requester(
389        undocumented_oauth2_access,
390        &clock,
391        &activity_tracker,
392        repo,
393        &session_info,
394        user_agent,
395        token,
396    )
397    .await?;
398
399    let request =
400        async_graphql::http::parse_query_string(&query.unwrap_or_default())?.data(requester);
401
402    let span = span_for_graphql_request(&request);
403    let mut response = schema.execute(request).instrument(span).await;
404
405    if has_session_ended(&mut response) {
406        let session_info = session_info.mark_session_ended();
407        cookie_jar = cookie_jar.update_session_info(&session_info);
408    }
409
410    let cache_control = response
411        .cache_control
412        .value()
413        .and_then(|v| HeaderValue::from_str(&v).ok())
414        .map(|h| [(CACHE_CONTROL, h)]);
415
416    let headers = response.http_headers.clone();
417
418    Ok((headers, cache_control, cookie_jar, Json(response)))
419}
420
421pub async fn playground() -> impl IntoResponse {
422    Html(playground_source(
423        GraphQLPlaygroundConfig::new("/graphql").with_setting("request.credentials", "include"),
424    ))
425}
426
427pub type Schema = async_graphql::Schema<Query, Mutation, EmptySubscription>;
428pub type SchemaBuilder = async_graphql::SchemaBuilder<Query, Mutation, EmptySubscription>;
429
430#[must_use]
431pub fn schema_builder() -> SchemaBuilder {
432    async_graphql::Schema::build(Query::new(), Mutation::new(), EmptySubscription)
433        .register_output_type::<Node>()
434        .register_output_type::<CreationEvent>()
435}
436
437pub struct Requester {
438    entity: RequestingEntity,
439    ip_address: Option<IpAddr>,
440    user_agent: Option<String>,
441}
442
443impl Requester {
444    pub fn fingerprint(&self) -> RequesterFingerprint {
445        if let Some(ip) = self.ip_address {
446            RequesterFingerprint::new(ip)
447        } else {
448            RequesterFingerprint::EMPTY
449        }
450    }
451
452    pub fn for_policy(&self) -> mas_policy::Requester {
453        mas_policy::Requester {
454            ip_address: self.ip_address,
455            user_agent: self.user_agent.clone(),
456        }
457    }
458}
459
460impl Deref for Requester {
461    type Target = RequestingEntity;
462
463    fn deref(&self) -> &Self::Target {
464        &self.entity
465    }
466}
467
468/// The identity of the requester.
469#[derive(Debug, Clone, Default, PartialEq, Eq)]
470pub enum RequestingEntity {
471    /// The requester presented no authentication information.
472    #[default]
473    Anonymous,
474
475    /// The requester is a browser session, stored in a cookie.
476    BrowserSession(Box<BrowserSession>),
477
478    /// The requester is a `OAuth2` session, with an access token.
479    OAuth2Session(Box<(Session, Option<User>)>),
480}
481
482trait OwnerId {
483    fn owner_id(&self) -> Option<Ulid>;
484}
485
486impl OwnerId for User {
487    fn owner_id(&self) -> Option<Ulid> {
488        Some(self.id)
489    }
490}
491
492impl OwnerId for BrowserSession {
493    fn owner_id(&self) -> Option<Ulid> {
494        Some(self.user.id)
495    }
496}
497
498impl OwnerId for mas_data_model::UserEmail {
499    fn owner_id(&self) -> Option<Ulid> {
500        Some(self.user_id)
501    }
502}
503
504impl OwnerId for Session {
505    fn owner_id(&self) -> Option<Ulid> {
506        self.user_id
507    }
508}
509
510impl OwnerId for mas_data_model::CompatSession {
511    fn owner_id(&self) -> Option<Ulid> {
512        Some(self.user_id)
513    }
514}
515
516impl OwnerId for mas_data_model::UpstreamOAuthLink {
517    fn owner_id(&self) -> Option<Ulid> {
518        self.user_id
519    }
520}
521
522/// A dumb wrapper around a `Ulid` to implement `OwnerId` for it.
523pub struct UserId(Ulid);
524
525impl OwnerId for UserId {
526    fn owner_id(&self) -> Option<Ulid> {
527        Some(self.0)
528    }
529}
530
531impl RequestingEntity {
532    fn browser_session(&self) -> Option<&BrowserSession> {
533        match self {
534            Self::BrowserSession(session) => Some(session),
535            Self::OAuth2Session(_) | Self::Anonymous => None,
536        }
537    }
538
539    fn user(&self) -> Option<&User> {
540        match self {
541            Self::BrowserSession(session) => Some(&session.user),
542            Self::OAuth2Session(tuple) => tuple.1.as_ref(),
543            Self::Anonymous => None,
544        }
545    }
546
547    fn oauth2_session(&self) -> Option<&Session> {
548        match self {
549            Self::OAuth2Session(tuple) => Some(&tuple.0),
550            Self::BrowserSession(_) | Self::Anonymous => None,
551        }
552    }
553
554    /// Returns true if the requester can access the resource.
555    fn is_owner_or_admin(&self, resource: &impl OwnerId) -> bool {
556        // If the requester is an admin, they can do anything.
557        if self.is_admin() {
558            return true;
559        }
560
561        // Otherwise, they must be the owner of the resource.
562        let Some(owner_id) = resource.owner_id() else {
563            return false;
564        };
565
566        let Some(user) = self.user() else {
567            return false;
568        };
569
570        user.id == owner_id
571    }
572
573    fn is_admin(&self) -> bool {
574        match self {
575            Self::OAuth2Session(tuple) => {
576                // TODO: is this the right scope?
577                // This has to be in sync with the policy
578                tuple.0.scope.contains("urn:mas:admin")
579            }
580            Self::BrowserSession(_) | Self::Anonymous => false,
581        }
582    }
583
584    fn is_unauthenticated(&self) -> bool {
585        matches!(self, Self::Anonymous)
586    }
587}
588
589impl From<BrowserSession> for RequestingEntity {
590    fn from(session: BrowserSession) -> Self {
591        Self::BrowserSession(Box::new(session))
592    }
593}
594
595impl<T> From<Option<T>> for RequestingEntity
596where
597    T: Into<RequestingEntity>,
598{
599    fn from(session: Option<T>) -> Self {
600        session.map(Into::into).unwrap_or_default()
601    }
602}
603
604/// A filter for dates, with a lower bound and an upper bound
605#[derive(InputObject, Default, Clone, Copy)]
606pub struct DateFilter {
607    /// The lower bound of the date range
608    after: Option<DateTime<Utc>>,
609
610    /// The upper bound of the date range
611    before: Option<DateTime<Utc>>,
612}