1use std::net::IpAddr;
9
10use async_trait::async_trait;
11use chrono::{DateTime, Utc};
12use mas_data_model::{
13 BrowserSession, Clock, DeviceCodeGrant, DeviceCodeGrantState, Session, UlidExt as _,
14};
15use mas_storage::oauth2::{OAuth2DeviceCodeGrantParams, OAuth2DeviceCodeGrantRepository};
16use oauth2_types::scope::Scope;
17use rand::RngCore;
18use sqlx::PgConnection;
19use ulid::Ulid;
20use uuid::Uuid;
21
22use crate::{DatabaseError, ExecuteExt, errors::DatabaseInconsistencyError};
23
24pub struct PgOAuth2DeviceCodeGrantRepository<'c> {
27 conn: &'c mut PgConnection,
28}
29
30impl<'c> PgOAuth2DeviceCodeGrantRepository<'c> {
31 pub fn new(conn: &'c mut PgConnection) -> Self {
34 Self { conn }
35 }
36}
37
38struct OAuth2DeviceGrantLookup {
39 oauth2_device_code_grant_id: Uuid,
40 oauth2_client_id: Uuid,
41 scope: String,
42 device_code: String,
43 user_code: String,
44 created_at: DateTime<Utc>,
45 expires_at: DateTime<Utc>,
46 fulfilled_at: Option<DateTime<Utc>>,
47 rejected_at: Option<DateTime<Utc>>,
48 exchanged_at: Option<DateTime<Utc>>,
49 user_session_id: Option<Uuid>,
50 oauth2_session_id: Option<Uuid>,
51 ip_address: Option<IpAddr>,
52 user_agent: Option<String>,
53 locale: Option<String>,
54}
55
56impl TryFrom<OAuth2DeviceGrantLookup> for DeviceCodeGrant {
57 type Error = DatabaseInconsistencyError;
58
59 fn try_from(
60 OAuth2DeviceGrantLookup {
61 oauth2_device_code_grant_id,
62 oauth2_client_id,
63 scope,
64 device_code,
65 user_code,
66 created_at,
67 expires_at,
68 fulfilled_at,
69 rejected_at,
70 exchanged_at,
71 user_session_id,
72 oauth2_session_id,
73 ip_address,
74 user_agent,
75 locale,
76 }: OAuth2DeviceGrantLookup,
77 ) -> Result<Self, Self::Error> {
78 let id = Ulid::from(oauth2_device_code_grant_id);
79 let client_id = Ulid::from(oauth2_client_id);
80
81 let scope: Scope = scope.parse().map_err(|e| {
82 DatabaseInconsistencyError::on("oauth2_authorization_grants")
83 .column("scope")
84 .row(id)
85 .source(e)
86 })?;
87
88 let state = match (
89 fulfilled_at,
90 rejected_at,
91 exchanged_at,
92 user_session_id,
93 oauth2_session_id,
94 ) {
95 (None, None, None, None, None) => DeviceCodeGrantState::Pending,
96
97 (Some(fulfilled_at), None, None, Some(user_session_id), None) => {
98 DeviceCodeGrantState::Fulfilled {
99 browser_session_id: Ulid::from(user_session_id),
100 fulfilled_at,
101 }
102 }
103
104 (None, Some(rejected_at), None, Some(user_session_id), None) => {
105 DeviceCodeGrantState::Rejected {
106 browser_session_id: Ulid::from(user_session_id),
107 rejected_at,
108 }
109 }
110
111 (
112 Some(fulfilled_at),
113 None,
114 Some(exchanged_at),
115 Some(user_session_id),
116 Some(oauth2_session_id),
117 ) => DeviceCodeGrantState::Exchanged {
118 browser_session_id: Ulid::from(user_session_id),
119 session_id: Ulid::from(oauth2_session_id),
120 fulfilled_at,
121 exchanged_at,
122 },
123
124 _ => return Err(DatabaseInconsistencyError::on("oauth2_device_code_grant").row(id)),
125 };
126
127 Ok(DeviceCodeGrant {
128 id,
129 state,
130 client_id,
131 scope,
132 user_code,
133 device_code,
134 created_at,
135 expires_at,
136 ip_address,
137 user_agent,
138 locale,
139 })
140 }
141}
142
143#[async_trait]
144impl OAuth2DeviceCodeGrantRepository for PgOAuth2DeviceCodeGrantRepository<'_> {
145 type Error = DatabaseError;
146
147 #[tracing::instrument(
148 name = "db.oauth2_device_code_grant.add",
149 skip_all,
150 fields(
151 db.query.text,
152 oauth2_device_code.id,
153 oauth2_device_code.scope = %params.scope,
154 oauth2_client.id = %params.client.id,
155 ),
156 err,
157 )]
158 async fn add(
159 &mut self,
160 rng: &mut (dyn RngCore + Send),
161 clock: &dyn Clock,
162 params: OAuth2DeviceCodeGrantParams<'_>,
163 ) -> Result<DeviceCodeGrant, Self::Error> {
164 let now = clock.now();
165 let id = Ulid::from_datetime_with_rng(now, rng);
166 tracing::Span::current().record("oauth2_device_code.id", tracing::field::display(id));
167
168 let created_at = now;
169 let expires_at = now + params.expires_in;
170 let client_id = params.client.id;
171
172 sqlx::query!(
173 r#"
174 INSERT INTO "oauth2_device_code_grant"
175 ( oauth2_device_code_grant_id
176 , oauth2_client_id
177 , scope
178 , device_code
179 , user_code
180 , created_at
181 , expires_at
182 , ip_address
183 , user_agent
184 )
185 VALUES
186 ($1, $2, $3, $4, $5, $6, $7, $8, $9)
187 "#,
188 Uuid::from(id),
189 Uuid::from(client_id),
190 params.scope.to_string(),
191 ¶ms.device_code,
192 ¶ms.user_code,
193 created_at,
194 expires_at,
195 params.ip_address as Option<IpAddr>,
196 params.user_agent.as_deref(),
197 )
198 .traced()
199 .execute(&mut *self.conn)
200 .await?;
201
202 Ok(DeviceCodeGrant {
203 id,
204 state: DeviceCodeGrantState::Pending,
205 client_id,
206 scope: params.scope,
207 user_code: params.user_code,
208 device_code: params.device_code,
209 created_at,
210 expires_at,
211 ip_address: params.ip_address,
212 user_agent: params.user_agent,
213 locale: None,
214 })
215 }
216
217 #[tracing::instrument(
218 name = "db.oauth2_device_code_grant.lookup",
219 skip_all,
220 fields(
221 db.query.text,
222 oauth2_device_code.id = %id,
223 ),
224 err,
225 )]
226 async fn lookup(&mut self, id: Ulid) -> Result<Option<DeviceCodeGrant>, Self::Error> {
227 let res = sqlx::query_as!(
228 OAuth2DeviceGrantLookup,
229 r#"
230 SELECT oauth2_device_code_grant_id
231 , oauth2_client_id
232 , scope
233 , device_code
234 , user_code
235 , created_at
236 , expires_at
237 , fulfilled_at
238 , rejected_at
239 , exchanged_at
240 , user_session_id
241 , oauth2_session_id
242 , ip_address as "ip_address: IpAddr"
243 , user_agent
244 , locale
245 FROM
246 oauth2_device_code_grant
247
248 WHERE oauth2_device_code_grant_id = $1
249 "#,
250 Uuid::from(id),
251 )
252 .traced()
253 .fetch_optional(&mut *self.conn)
254 .await?;
255
256 let Some(res) = res else { return Ok(None) };
257
258 Ok(Some(res.try_into()?))
259 }
260
261 #[tracing::instrument(
262 name = "db.oauth2_device_code_grant.find_by_user_code",
263 skip_all,
264 fields(
265 db.query.text,
266 oauth2_device_code.user_code = %user_code,
267 ),
268 err,
269 )]
270 async fn find_by_user_code(
271 &mut self,
272 user_code: &str,
273 ) -> Result<Option<DeviceCodeGrant>, Self::Error> {
274 let res = sqlx::query_as!(
275 OAuth2DeviceGrantLookup,
276 r#"
277 SELECT oauth2_device_code_grant_id
278 , oauth2_client_id
279 , scope
280 , device_code
281 , user_code
282 , created_at
283 , expires_at
284 , fulfilled_at
285 , rejected_at
286 , exchanged_at
287 , user_session_id
288 , oauth2_session_id
289 , ip_address as "ip_address: IpAddr"
290 , user_agent
291 , locale
292 FROM
293 oauth2_device_code_grant
294
295 WHERE user_code = $1
296 "#,
297 user_code,
298 )
299 .traced()
300 .fetch_optional(&mut *self.conn)
301 .await?;
302
303 let Some(res) = res else { return Ok(None) };
304
305 Ok(Some(res.try_into()?))
306 }
307
308 #[tracing::instrument(
309 name = "db.oauth2_device_code_grant.find_by_device_code",
310 skip_all,
311 fields(
312 db.query.text,
313 oauth2_device_code.device_code = %device_code,
314 ),
315 err,
316 )]
317 async fn find_by_device_code(
318 &mut self,
319 device_code: &str,
320 ) -> Result<Option<DeviceCodeGrant>, Self::Error> {
321 let res = sqlx::query_as!(
322 OAuth2DeviceGrantLookup,
323 r#"
324 SELECT oauth2_device_code_grant_id
325 , oauth2_client_id
326 , scope
327 , device_code
328 , user_code
329 , created_at
330 , expires_at
331 , fulfilled_at
332 , rejected_at
333 , exchanged_at
334 , user_session_id
335 , oauth2_session_id
336 , ip_address as "ip_address: IpAddr"
337 , user_agent
338 , locale
339 FROM
340 oauth2_device_code_grant
341
342 WHERE device_code = $1
343 "#,
344 device_code,
345 )
346 .traced()
347 .fetch_optional(&mut *self.conn)
348 .await?;
349
350 let Some(res) = res else { return Ok(None) };
351
352 Ok(Some(res.try_into()?))
353 }
354
355 #[tracing::instrument(
356 name = "db.oauth2_device_code_grant.fulfill",
357 skip_all,
358 fields(
359 db.query.text,
360 oauth2_device_code.id = %device_code_grant.id,
361 oauth2_client.id = %device_code_grant.client_id,
362 browser_session.id = %browser_session.id,
363 user.id = %browser_session.user.id,
364 ),
365 err,
366 )]
367 async fn fulfill(
368 &mut self,
369 clock: &dyn Clock,
370 device_code_grant: DeviceCodeGrant,
371 browser_session: &BrowserSession,
372 locale: Option<String>,
373 ) -> Result<DeviceCodeGrant, Self::Error> {
374 let fulfilled_at = clock.now();
375 let device_code_grant = device_code_grant
376 .fulfill(browser_session, locale, fulfilled_at)
377 .map_err(DatabaseError::to_invalid_operation)?;
378
379 let res = sqlx::query!(
380 r#"
381 UPDATE oauth2_device_code_grant
382 SET fulfilled_at = $1
383 , user_session_id = $2
384 , locale = $3
385 WHERE oauth2_device_code_grant_id = $4
386 "#,
387 fulfilled_at,
388 Uuid::from(browser_session.id),
389 device_code_grant.locale.as_deref(),
390 Uuid::from(device_code_grant.id),
391 )
392 .traced()
393 .execute(&mut *self.conn)
394 .await?;
395
396 DatabaseError::ensure_affected_rows(&res, 1)?;
397
398 Ok(device_code_grant)
399 }
400
401 #[tracing::instrument(
402 name = "db.oauth2_device_code_grant.reject",
403 skip_all,
404 fields(
405 db.query.text,
406 oauth2_device_code.id = %device_code_grant.id,
407 oauth2_client.id = %device_code_grant.client_id,
408 browser_session.id = %browser_session.id,
409 user.id = %browser_session.user.id,
410 ),
411 err,
412 )]
413 async fn reject(
414 &mut self,
415 clock: &dyn Clock,
416 device_code_grant: DeviceCodeGrant,
417 browser_session: &BrowserSession,
418 ) -> Result<DeviceCodeGrant, Self::Error> {
419 let fulfilled_at = clock.now();
420 let device_code_grant = device_code_grant
421 .reject(browser_session, fulfilled_at)
422 .map_err(DatabaseError::to_invalid_operation)?;
423
424 let res = sqlx::query!(
425 r#"
426 UPDATE oauth2_device_code_grant
427 SET rejected_at = $1
428 , user_session_id = $2
429 WHERE oauth2_device_code_grant_id = $3
430 "#,
431 fulfilled_at,
432 Uuid::from(browser_session.id),
433 Uuid::from(device_code_grant.id),
434 )
435 .traced()
436 .execute(&mut *self.conn)
437 .await?;
438
439 DatabaseError::ensure_affected_rows(&res, 1)?;
440
441 Ok(device_code_grant)
442 }
443
444 #[tracing::instrument(
445 name = "db.oauth2_device_code_grant.exchange",
446 skip_all,
447 fields(
448 db.query.text,
449 oauth2_device_code.id = %device_code_grant.id,
450 oauth2_client.id = %device_code_grant.client_id,
451 oauth2_session.id = %session.id,
452 ),
453 err,
454 )]
455 async fn exchange(
456 &mut self,
457 clock: &dyn Clock,
458 device_code_grant: DeviceCodeGrant,
459 session: &Session,
460 ) -> Result<DeviceCodeGrant, Self::Error> {
461 let exchanged_at = clock.now();
462 let device_code_grant = device_code_grant
463 .exchange(session, exchanged_at)
464 .map_err(DatabaseError::to_invalid_operation)?;
465
466 let res = sqlx::query!(
467 r#"
468 UPDATE oauth2_device_code_grant
469 SET exchanged_at = $1
470 , oauth2_session_id = $2
471 WHERE oauth2_device_code_grant_id = $3
472 "#,
473 exchanged_at,
474 Uuid::from(session.id),
475 Uuid::from(device_code_grant.id),
476 )
477 .traced()
478 .execute(&mut *self.conn)
479 .await?;
480
481 DatabaseError::ensure_affected_rows(&res, 1)?;
482
483 Ok(device_code_grant)
484 }
485
486 #[tracing::instrument(
487 name = "db.oauth2_device_code_grant.cleanup",
488 skip_all,
489 fields(
490 db.query.text,
491 since = since.map(tracing::field::display),
492 until = %until,
493 limit = limit,
494 ),
495 err,
496 )]
497 async fn cleanup(
498 &mut self,
499 since: Option<Ulid>,
500 until: Ulid,
501 limit: usize,
502 ) -> Result<(usize, Option<Ulid>), Self::Error> {
503 let res = sqlx::query_scalar!(
508 r#"
509 WITH to_delete AS (
510 SELECT oauth2_device_code_grant_id
511 FROM oauth2_device_code_grant
512 WHERE ($1::uuid IS NULL OR oauth2_device_code_grant_id > $1)
513 AND oauth2_device_code_grant_id <= $2
514 ORDER BY oauth2_device_code_grant_id
515 LIMIT $3
516 )
517 DELETE FROM oauth2_device_code_grant
518 USING to_delete
519 WHERE oauth2_device_code_grant.oauth2_device_code_grant_id = to_delete.oauth2_device_code_grant_id
520 RETURNING oauth2_device_code_grant.oauth2_device_code_grant_id
521 "#,
522 since.map(Uuid::from),
523 Uuid::from(until),
524 i64::try_from(limit).unwrap_or(i64::MAX)
525 )
526 .traced()
527 .fetch_all(&mut *self.conn)
528 .await?;
529
530 let count = res.len();
531 let max_id = res.into_iter().max();
532
533 Ok((count, max_id.map(Ulid::from)))
534 }
535}