mas_handlers/oauth2/device/consent.rs

1// Copyright 2024, 2025 New Vector Ltd.
2// Copyright 2023, 2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
5// Please see LICENSE files in the repository root for full details.
6
7use std::{sync::Arc, time::Duration};
8
9use anyhow::Context;
10use axum::{
11    Form,
12    extract::{Path, State},
13    response::{Html, IntoResponse, Response},
14};
15use axum_extra::TypedHeader;
16use mas_axum_utils::{
17    InternalError,
18    cookies::CookieJar,
19    csrf::{CsrfExt, ProtectedForm},
20};
21use mas_data_model::{BoxClock, BoxRng, MatrixUser};
22use mas_matrix::HomeserverConnection;
23use mas_policy::Policy;
24use mas_router::UrlBuilder;
25use mas_storage::BoxRepository;
26use mas_templates::{DeviceConsentContext, PolicyViolationContext, TemplateContext, Templates};
27use serde::Deserialize;
28use tracing::warn;
29use ulid::Ulid;
30
31use crate::{
32    BoundActivityTracker, PreferredLanguage,
33    session::{SessionOrFallback, count_user_sessions_for_limiting, load_session_or_fallback},
34};
35
36#[derive(Deserialize, Debug)]
37#[serde(rename_all = "lowercase")]
38enum Action {
39    Consent,
40    Reject,
41}
42
43#[derive(Deserialize, Debug)]
44pub(crate) struct ConsentForm {
45    action: Action,
46}
47
48#[tracing::instrument(name = "handlers.oauth2.device.consent.get", skip_all)]
49pub(crate) async fn get(
50    mut rng: BoxRng,
51    clock: BoxClock,
52    PreferredLanguage(locale): PreferredLanguage,
53    State(templates): State<Templates>,
54    State(url_builder): State<UrlBuilder>,
55    State(homeserver): State<Arc<dyn HomeserverConnection>>,
56    mut repo: BoxRepository,
57    mut policy: Policy,
58    activity_tracker: BoundActivityTracker,
59    user_agent: Option<TypedHeader<headers::UserAgent>>,
60    cookie_jar: CookieJar,
61    Path(grant_id): Path<Ulid>,
62) -> Result<Response, InternalError> {
63    let (cookie_jar, maybe_session) = match load_session_or_fallback(
64        cookie_jar, &clock, &mut rng, &templates, &locale, &mut repo,
65    )
66    .await?
67    {
68        SessionOrFallback::MaybeSession {
69            cookie_jar,
70            maybe_session,
71            ..
72        } => (cookie_jar, maybe_session),
73        SessionOrFallback::Fallback { response } => return Ok(response),
74    };
75
76    let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
77
78    let user_agent = user_agent.map(|ua| ua.to_string());
79
80    let Some(session) = maybe_session else {
81        let login = mas_router::Login::and_continue_device_code_grant(grant_id);
82        return Ok((cookie_jar, url_builder.redirect(&login)).into_response());
83    };
84
85    activity_tracker
86        .record_browser_session(&clock, &session)
87        .await;
88
89    // TODO: better error handling
90    let grant = repo
91        .oauth2_device_code_grant()
92        .lookup(grant_id)
93        .await?
94        .context("Device grant not found")
95        .map_err(InternalError::from_anyhow)?;
96
97    if grant.expires_at < clock.now() {
98        return Err(InternalError::from_anyhow(anyhow::anyhow!(
99            "Grant is expired"
100        )));
101    }
102
103    let client = repo
104        .oauth2_client()
105        .lookup(grant.client_id)
106        .await?
107        .context("Client not found")
108        .map_err(InternalError::from_anyhow)?;
109
110    let session_counts = count_user_sessions_for_limiting(&mut repo, &session.user).await?;
111
112    // We can close the repository early, we don't need it at this point
113    repo.save().await?;
114
115    // Evaluate the policy
116    let res = policy
117        .evaluate_authorization_grant(mas_policy::AuthorizationGrantInput {
118            grant_type: mas_policy::GrantType::DeviceCode,
119            client: &client,
120            session_counts: Some(session_counts),
121            scope: &grant.scope,
122            user: Some(&session.user),
123            requester: mas_policy::Requester {
124                ip_address: activity_tracker.ip(),
125                user_agent,
126            },
127        })
128        .await?;
129    if !res.valid() {
130        warn!(violation = ?res, "Device code grant for client {} denied by policy", client.id);
131
132        let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
133        let ctx = PolicyViolationContext::for_device_code_grant(grant, client)
134            .with_session(session)
135            .with_csrf(csrf_token.form_value())
136            .with_language(locale);
137
138        let content = templates.render_policy_violation(&ctx)?;
139
140        return Ok((cookie_jar, Html(content)).into_response());
141    }
142
143    // Fetch informations about the user. This is purely cosmetic, so we let it
144    // fail and put a 1s timeout to it in case we fail to query it
145    // XXX: we're likely to need this in other places
146    let localpart = &session.user.username;
147    let display_name = match tokio::time::timeout(
148        Duration::from_secs(1),
149        homeserver.query_user(localpart),
150    )
151    .await
152    {
153        Ok(Ok(user)) => user.displayname,
154        Ok(Err(err)) => {
155            tracing::warn!(
156                error = &*err as &dyn std::error::Error,
157                localpart,
158                "Failed to query user"
159            );
160            None
161        }
162        Err(_) => {
163            tracing::warn!(localpart, "Timed out while querying user");
164            None
165        }
166    };
167
168    let matrix_user = MatrixUser {
169        mxid: homeserver.mxid(localpart),
170        display_name,
171    };
172
173    let ctx = DeviceConsentContext::new(grant, client, matrix_user)
174        .with_session(session)
175        .with_csrf(csrf_token.form_value())
176        .with_language(locale);
177
178    let rendered = templates
179        .render_device_consent(&ctx)
180        .context("Failed to render template")
181        .map_err(InternalError::from_anyhow)?;
182
183    Ok((cookie_jar, Html(rendered)).into_response())
184}
185
186#[tracing::instrument(name = "handlers.oauth2.device.consent.post", skip_all)]
187pub(crate) async fn post(
188    mut rng: BoxRng,
189    clock: BoxClock,
190    PreferredLanguage(locale): PreferredLanguage,
191    State(templates): State<Templates>,
192    State(url_builder): State<UrlBuilder>,
193    State(homeserver): State<Arc<dyn HomeserverConnection>>,
194    mut repo: BoxRepository,
195    mut policy: Policy,
196    activity_tracker: BoundActivityTracker,
197    user_agent: Option<TypedHeader<headers::UserAgent>>,
198    cookie_jar: CookieJar,
199    Path(grant_id): Path<Ulid>,
200    Form(form): Form<ProtectedForm<ConsentForm>>,
201) -> Result<Response, InternalError> {
202    let form = cookie_jar.verify_form(&clock, form)?;
203    let (cookie_jar, maybe_session) = match load_session_or_fallback(
204        cookie_jar, &clock, &mut rng, &templates, &locale, &mut repo,
205    )
206    .await?
207    {
208        SessionOrFallback::MaybeSession {
209            cookie_jar,
210            maybe_session,
211            ..
212        } => (cookie_jar, maybe_session),
213        SessionOrFallback::Fallback { response } => return Ok(response),
214    };
215    let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
216
217    let user_agent = user_agent.map(|TypedHeader(ua)| ua.to_string());
218
219    let Some(session) = maybe_session else {
220        let login = mas_router::Login::and_continue_device_code_grant(grant_id);
221        return Ok((cookie_jar, url_builder.redirect(&login)).into_response());
222    };
223
224    activity_tracker
225        .record_browser_session(&clock, &session)
226        .await;
227
228    // TODO: better error handling
229    let grant = repo
230        .oauth2_device_code_grant()
231        .lookup(grant_id)
232        .await?
233        .context("Device grant not found")
234        .map_err(InternalError::from_anyhow)?;
235
236    if grant.expires_at < clock.now() {
237        return Err(InternalError::from_anyhow(anyhow::anyhow!(
238            "Grant is expired"
239        )));
240    }
241
242    let client = repo
243        .oauth2_client()
244        .lookup(grant.client_id)
245        .await?
246        .context("Client not found")
247        .map_err(InternalError::from_anyhow)?;
248
249    let session_counts = count_user_sessions_for_limiting(&mut repo, &session.user).await?;
250
251    // Evaluate the policy
252    let res = policy
253        .evaluate_authorization_grant(mas_policy::AuthorizationGrantInput {
254            grant_type: mas_policy::GrantType::DeviceCode,
255            client: &client,
256            session_counts: Some(session_counts),
257            scope: &grant.scope,
258            user: Some(&session.user),
259            requester: mas_policy::Requester {
260                ip_address: activity_tracker.ip(),
261                user_agent,
262            },
263        })
264        .await?;
265    if !res.valid() {
266        warn!(violation = ?res, "Device code grant for client {} denied by policy", client.id);
267
268        let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
269        let ctx = PolicyViolationContext::for_device_code_grant(grant, client)
270            .with_session(session)
271            .with_csrf(csrf_token.form_value())
272            .with_language(locale);
273
274        let content = templates.render_policy_violation(&ctx)?;
275
276        return Ok((cookie_jar, Html(content)).into_response());
277    }
278
279    let grant = if grant.is_pending() {
280        match form.action {
281            Action::Consent => {
282                repo.oauth2_device_code_grant()
283                    .fulfill(&clock, grant, &session)
284                    .await?
285            }
286            Action::Reject => {
287                repo.oauth2_device_code_grant()
288                    .reject(&clock, grant, &session)
289                    .await?
290            }
291        }
292    } else {
293        // XXX: In case we're not pending, let's just return the grant as-is
294        // since it might just be a form resubmission, and feedback is nice enough
295        warn!(
296            oauth2_device_code.id = %grant.id,
297            browser_session.id = %session.id,
298            user.id = %session.user.id,
299            "Grant is not pending",
300        );
301        grant
302    };
303
304    repo.save().await?;
305
306    // Fetch informations about the user. This is purely cosmetic, so we let it
307    // fail and put a 1s timeout to it in case we fail to query it
308    // XXX: we're likely to need this in other places
309    let localpart = &session.user.username;
310    let display_name = match tokio::time::timeout(
311        Duration::from_secs(1),
312        homeserver.query_user(localpart),
313    )
314    .await
315    {
316        Ok(Ok(user)) => user.displayname,
317        Ok(Err(err)) => {
318            tracing::warn!(
319                error = &*err as &dyn std::error::Error,
320                localpart,
321                "Failed to query user"
322            );
323            None
324        }
325        Err(_) => {
326            tracing::warn!(localpart, "Timed out while querying user");
327            None
328        }
329    };
330
331    let matrix_user = MatrixUser {
332        mxid: homeserver.mxid(localpart),
333        display_name,
334    };
335
336    let ctx = DeviceConsentContext::new(grant, client, matrix_user)
337        .with_session(session)
338        .with_csrf(csrf_token.form_value())
339        .with_language(locale);
340
341    let rendered = templates
342        .render_device_consent(&ctx)
343        .context("Failed to render template")
344        .map_err(InternalError::from_anyhow)?;
345
346    Ok((cookie_jar, Html(rendered)).into_response())
347}