mas_handlers/oauth2/device/
consent.rs1use std::{sync::Arc, time::Duration};
8
9use anyhow::Context;
10use axum::{
11 Form,
12 extract::{Path, State},
13 response::{Html, IntoResponse, Response},
14};
15use axum_extra::TypedHeader;
16use mas_axum_utils::{
17 InternalError,
18 cookies::CookieJar,
19 csrf::{CsrfExt, ProtectedForm},
20};
21use mas_data_model::{BoxClock, BoxRng, MatrixUser};
22use mas_matrix::HomeserverConnection;
23use mas_policy::Policy;
24use mas_router::UrlBuilder;
25use mas_storage::BoxRepository;
26use mas_templates::{DeviceConsentContext, PolicyViolationContext, TemplateContext, Templates};
27use serde::Deserialize;
28use tracing::warn;
29use ulid::Ulid;
30
31use crate::{
32 BoundActivityTracker, PreferredLanguage,
33 session::{SessionOrFallback, count_user_sessions_for_limiting, load_session_or_fallback},
34};
35
36#[derive(Deserialize, Debug)]
37#[serde(rename_all = "lowercase")]
38enum Action {
39 Consent,
40 Reject,
41}
42
43#[derive(Deserialize, Debug)]
44pub(crate) struct ConsentForm {
45 action: Action,
46}
47
48#[tracing::instrument(name = "handlers.oauth2.device.consent.get", skip_all)]
49pub(crate) async fn get(
50 mut rng: BoxRng,
51 clock: BoxClock,
52 PreferredLanguage(locale): PreferredLanguage,
53 State(templates): State<Templates>,
54 State(url_builder): State<UrlBuilder>,
55 State(homeserver): State<Arc<dyn HomeserverConnection>>,
56 mut repo: BoxRepository,
57 mut policy: Policy,
58 activity_tracker: BoundActivityTracker,
59 user_agent: Option<TypedHeader<headers::UserAgent>>,
60 cookie_jar: CookieJar,
61 Path(grant_id): Path<Ulid>,
62) -> Result<Response, InternalError> {
63 let (cookie_jar, maybe_session) = match load_session_or_fallback(
64 cookie_jar, &clock, &mut rng, &templates, &locale, &mut repo,
65 )
66 .await?
67 {
68 SessionOrFallback::MaybeSession {
69 cookie_jar,
70 maybe_session,
71 ..
72 } => (cookie_jar, maybe_session),
73 SessionOrFallback::Fallback { response } => return Ok(response),
74 };
75
76 let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
77
78 let user_agent = user_agent.map(|ua| ua.to_string());
79
80 let Some(session) = maybe_session else {
81 let login = mas_router::Login::and_continue_device_code_grant(grant_id);
82 return Ok((cookie_jar, url_builder.redirect(&login)).into_response());
83 };
84
85 activity_tracker
86 .record_browser_session(&clock, &session)
87 .await;
88
89 let grant = repo
91 .oauth2_device_code_grant()
92 .lookup(grant_id)
93 .await?
94 .context("Device grant not found")
95 .map_err(InternalError::from_anyhow)?;
96
97 if grant.expires_at < clock.now() {
98 return Err(InternalError::from_anyhow(anyhow::anyhow!(
99 "Grant is expired"
100 )));
101 }
102
103 let client = repo
104 .oauth2_client()
105 .lookup(grant.client_id)
106 .await?
107 .context("Client not found")
108 .map_err(InternalError::from_anyhow)?;
109
110 let session_counts = count_user_sessions_for_limiting(&mut repo, &session.user).await?;
111
112 repo.save().await?;
114
115 let res = policy
117 .evaluate_authorization_grant(mas_policy::AuthorizationGrantInput {
118 grant_type: mas_policy::GrantType::DeviceCode,
119 client: &client,
120 session_counts: Some(session_counts),
121 scope: &grant.scope,
122 user: Some(&session.user),
123 requester: mas_policy::Requester {
124 ip_address: activity_tracker.ip(),
125 user_agent,
126 },
127 })
128 .await?;
129 if !res.valid() {
130 warn!(violation = ?res, "Device code grant for client {} denied by policy", client.id);
131
132 let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
133 let ctx = PolicyViolationContext::for_device_code_grant(grant, client)
134 .with_session(session)
135 .with_csrf(csrf_token.form_value())
136 .with_language(locale);
137
138 let content = templates.render_policy_violation(&ctx)?;
139
140 return Ok((cookie_jar, Html(content)).into_response());
141 }
142
143 let localpart = &session.user.username;
147 let display_name = match tokio::time::timeout(
148 Duration::from_secs(1),
149 homeserver.query_user(localpart),
150 )
151 .await
152 {
153 Ok(Ok(user)) => user.displayname,
154 Ok(Err(err)) => {
155 tracing::warn!(
156 error = &*err as &dyn std::error::Error,
157 localpart,
158 "Failed to query user"
159 );
160 None
161 }
162 Err(_) => {
163 tracing::warn!(localpart, "Timed out while querying user");
164 None
165 }
166 };
167
168 let matrix_user = MatrixUser {
169 mxid: homeserver.mxid(localpart),
170 display_name,
171 };
172
173 let ctx = DeviceConsentContext::new(grant, client, matrix_user)
174 .with_session(session)
175 .with_csrf(csrf_token.form_value())
176 .with_language(locale);
177
178 let rendered = templates
179 .render_device_consent(&ctx)
180 .context("Failed to render template")
181 .map_err(InternalError::from_anyhow)?;
182
183 Ok((cookie_jar, Html(rendered)).into_response())
184}
185
186#[tracing::instrument(name = "handlers.oauth2.device.consent.post", skip_all)]
187pub(crate) async fn post(
188 mut rng: BoxRng,
189 clock: BoxClock,
190 PreferredLanguage(locale): PreferredLanguage,
191 State(templates): State<Templates>,
192 State(url_builder): State<UrlBuilder>,
193 State(homeserver): State<Arc<dyn HomeserverConnection>>,
194 mut repo: BoxRepository,
195 mut policy: Policy,
196 activity_tracker: BoundActivityTracker,
197 user_agent: Option<TypedHeader<headers::UserAgent>>,
198 cookie_jar: CookieJar,
199 Path(grant_id): Path<Ulid>,
200 Form(form): Form<ProtectedForm<ConsentForm>>,
201) -> Result<Response, InternalError> {
202 let form = cookie_jar.verify_form(&clock, form)?;
203 let (cookie_jar, maybe_session) = match load_session_or_fallback(
204 cookie_jar, &clock, &mut rng, &templates, &locale, &mut repo,
205 )
206 .await?
207 {
208 SessionOrFallback::MaybeSession {
209 cookie_jar,
210 maybe_session,
211 ..
212 } => (cookie_jar, maybe_session),
213 SessionOrFallback::Fallback { response } => return Ok(response),
214 };
215 let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
216
217 let user_agent = user_agent.map(|TypedHeader(ua)| ua.to_string());
218
219 let Some(session) = maybe_session else {
220 let login = mas_router::Login::and_continue_device_code_grant(grant_id);
221 return Ok((cookie_jar, url_builder.redirect(&login)).into_response());
222 };
223
224 activity_tracker
225 .record_browser_session(&clock, &session)
226 .await;
227
228 let grant = repo
230 .oauth2_device_code_grant()
231 .lookup(grant_id)
232 .await?
233 .context("Device grant not found")
234 .map_err(InternalError::from_anyhow)?;
235
236 if grant.expires_at < clock.now() {
237 return Err(InternalError::from_anyhow(anyhow::anyhow!(
238 "Grant is expired"
239 )));
240 }
241
242 let client = repo
243 .oauth2_client()
244 .lookup(grant.client_id)
245 .await?
246 .context("Client not found")
247 .map_err(InternalError::from_anyhow)?;
248
249 let session_counts = count_user_sessions_for_limiting(&mut repo, &session.user).await?;
250
251 let res = policy
253 .evaluate_authorization_grant(mas_policy::AuthorizationGrantInput {
254 grant_type: mas_policy::GrantType::DeviceCode,
255 client: &client,
256 session_counts: Some(session_counts),
257 scope: &grant.scope,
258 user: Some(&session.user),
259 requester: mas_policy::Requester {
260 ip_address: activity_tracker.ip(),
261 user_agent,
262 },
263 })
264 .await?;
265 if !res.valid() {
266 warn!(violation = ?res, "Device code grant for client {} denied by policy", client.id);
267
268 let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
269 let ctx = PolicyViolationContext::for_device_code_grant(grant, client)
270 .with_session(session)
271 .with_csrf(csrf_token.form_value())
272 .with_language(locale);
273
274 let content = templates.render_policy_violation(&ctx)?;
275
276 return Ok((cookie_jar, Html(content)).into_response());
277 }
278
279 let grant = if grant.is_pending() {
280 match form.action {
281 Action::Consent => {
282 repo.oauth2_device_code_grant()
283 .fulfill(&clock, grant, &session)
284 .await?
285 }
286 Action::Reject => {
287 repo.oauth2_device_code_grant()
288 .reject(&clock, grant, &session)
289 .await?
290 }
291 }
292 } else {
293 warn!(
296 oauth2_device_code.id = %grant.id,
297 browser_session.id = %session.id,
298 user.id = %session.user.id,
299 "Grant is not pending",
300 );
301 grant
302 };
303
304 repo.save().await?;
305
306 let localpart = &session.user.username;
310 let display_name = match tokio::time::timeout(
311 Duration::from_secs(1),
312 homeserver.query_user(localpart),
313 )
314 .await
315 {
316 Ok(Ok(user)) => user.displayname,
317 Ok(Err(err)) => {
318 tracing::warn!(
319 error = &*err as &dyn std::error::Error,
320 localpart,
321 "Failed to query user"
322 );
323 None
324 }
325 Err(_) => {
326 tracing::warn!(localpart, "Timed out while querying user");
327 None
328 }
329 };
330
331 let matrix_user = MatrixUser {
332 mxid: homeserver.mxid(localpart),
333 display_name,
334 };
335
336 let ctx = DeviceConsentContext::new(grant, client, matrix_user)
337 .with_session(session)
338 .with_csrf(csrf_token.form_value())
339 .with_language(locale);
340
341 let rendered = templates
342 .render_device_consent(&ctx)
343 .context("Failed to render template")
344 .map_err(InternalError::from_anyhow)?;
345
346 Ok((cookie_jar, Html(rendered)).into_response())
347}