1use async_trait::async_trait;
9use chrono::{DateTime, Utc};
10use mas_data_model::{
11 BrowserSession, Clock, UlidExt as _, UpstreamOAuthAuthorizationSession,
12 UpstreamOAuthAuthorizationSessionState, UpstreamOAuthLink, UpstreamOAuthProvider,
13};
14use mas_storage::{
15 Page, Pagination,
16 pagination::Node,
17 upstream_oauth2::{UpstreamOAuthSessionFilter, UpstreamOAuthSessionRepository},
18};
19use rand::RngCore;
20use sea_query::{
21 Expr, ExprTrait, PostgresQueryBuilder, Query, enum_def, extension::postgres::PgExpr,
22};
23use sea_query_sqlx::SqlxBinder;
24use sqlx::PgConnection;
25use ulid::Ulid;
26use uuid::Uuid;
27
28use crate::{
29 DatabaseError, DatabaseInconsistencyError,
30 filter::{Filter, StatementExt},
31 iden::UpstreamOAuthAuthorizationSessions,
32 pagination::QueryBuilderExt,
33 tracing::ExecuteExt,
34};
35
36impl Filter for UpstreamOAuthSessionFilter<'_> {
37 fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
38 sea_query::Condition::all()
39 .add_option(self.provider().map(|provider| {
40 Expr::col((
41 UpstreamOAuthAuthorizationSessions::Table,
42 UpstreamOAuthAuthorizationSessions::UpstreamOAuthProviderId,
43 ))
44 .eq(Uuid::from(provider.id))
45 }))
46 .add_option(self.sub_claim().map(|sub| {
47 Expr::col((
48 UpstreamOAuthAuthorizationSessions::Table,
49 UpstreamOAuthAuthorizationSessions::IdTokenClaims,
50 ))
51 .cast_json_field("sub")
52 .eq(sub)
53 }))
54 .add_option(self.sid_claim().map(|sid| {
55 Expr::col((
56 UpstreamOAuthAuthorizationSessions::Table,
57 UpstreamOAuthAuthorizationSessions::IdTokenClaims,
58 ))
59 .cast_json_field("sid")
60 .eq(sid)
61 }))
62 }
63}
64
/// An implementation of [`UpstreamOAuthSessionRepository`] backed by a
/// PostgreSQL connection.
pub struct PgUpstreamOAuthSessionRepository<'c> {
    // Connection every query of this repository is executed on. It is a
    // mutable borrow held for `'c`, so the repository has exclusive use of it.
    conn: &'c mut PgConnection,
}
70
71impl<'c> PgUpstreamOAuthSessionRepository<'c> {
72 pub fn new(conn: &'c mut PgConnection) -> Self {
75 Self { conn }
76 }
77}
78
/// Raw row of the `upstream_oauth_authorization_sessions` table, as selected
/// by the `lookup` and `list` queries of the repository.
///
/// `#[enum_def]` generates the `SessionLookupIden` enum from these field
/// names; `list` uses its variants as column aliases so that the dynamically
/// built query decodes into this struct through `sqlx::FromRow`. Renaming a
/// field therefore changes the expected column alias.
#[derive(sqlx::FromRow)]
#[enum_def]
struct SessionLookup {
    // Primary key; also used as the pagination cursor.
    upstream_oauth_authorization_session_id: Uuid,
    upstream_oauth_provider_id: Uuid,
    // Set once the session has been completed with a link.
    upstream_oauth_link_id: Option<Uuid>,
    // Becomes `state_str` on the domain type; this is not the lifecycle state.
    state: String,
    code_challenge_verifier: Option<String>,
    nonce: Option<String>,
    id_token: Option<String>,
    id_token_claims: Option<serde_json::Value>,
    userinfo: Option<serde_json::Value>,
    created_at: DateTime<Utc>,
    // Together with `upstream_oauth_link_id` and the payload columns above,
    // the nullable timestamps below determine the session's lifecycle state
    // when the row is converted to `UpstreamOAuthAuthorizationSession`.
    completed_at: Option<DateTime<Utc>>,
    consumed_at: Option<DateTime<Utc>>,
    extra_callback_parameters: Option<serde_json::Value>,
    unlinked_at: Option<DateTime<Utc>>,
}
97
98impl Node<Ulid> for SessionLookup {
99 fn cursor(&self) -> Ulid {
100 self.upstream_oauth_authorization_session_id.into()
101 }
102}
103
104impl TryFrom<SessionLookup> for UpstreamOAuthAuthorizationSession {
105 type Error = DatabaseInconsistencyError;
106
107 fn try_from(value: SessionLookup) -> Result<Self, Self::Error> {
108 let id = value.upstream_oauth_authorization_session_id.into();
109 let state = match (
110 value.upstream_oauth_link_id,
111 value.id_token,
112 value.id_token_claims,
113 value.extra_callback_parameters,
114 value.userinfo,
115 value.completed_at,
116 value.consumed_at,
117 value.unlinked_at,
118 ) {
119 (None, None, None, None, None, None, None, None) => {
120 UpstreamOAuthAuthorizationSessionState::Pending
121 }
122 (
123 Some(link_id),
124 id_token,
125 id_token_claims,
126 extra_callback_parameters,
127 userinfo,
128 Some(completed_at),
129 None,
130 None,
131 ) => UpstreamOAuthAuthorizationSessionState::Completed {
132 completed_at,
133 link_id: link_id.into(),
134 id_token,
135 id_token_claims,
136 extra_callback_parameters,
137 userinfo,
138 },
139 (
140 Some(link_id),
141 id_token,
142 id_token_claims,
143 extra_callback_parameters,
144 userinfo,
145 Some(completed_at),
146 Some(consumed_at),
147 None,
148 ) => UpstreamOAuthAuthorizationSessionState::Consumed {
149 completed_at,
150 link_id: link_id.into(),
151 id_token,
152 id_token_claims,
153 extra_callback_parameters,
154 userinfo,
155 consumed_at,
156 },
157 (
158 _,
159 id_token,
160 id_token_claims,
161 _,
162 _,
163 Some(completed_at),
164 consumed_at,
165 Some(unlinked_at),
166 ) => UpstreamOAuthAuthorizationSessionState::Unlinked {
167 completed_at,
168 id_token,
169 id_token_claims,
170 consumed_at,
171 unlinked_at,
172 },
173 _ => {
174 return Err(DatabaseInconsistencyError::on(
175 "upstream_oauth_authorization_sessions",
176 )
177 .row(id));
178 }
179 };
180
181 Ok(Self {
182 id,
183 provider_id: value.upstream_oauth_provider_id.into(),
184 state_str: value.state,
185 nonce: value.nonce,
186 code_challenge_verifier: value.code_challenge_verifier,
187 created_at: value.created_at,
188 state,
189 })
190 }
191}
192
193#[async_trait]
194impl UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'_> {
195 type Error = DatabaseError;
196
197 #[tracing::instrument(
198 name = "db.upstream_oauth_authorization_session.lookup",
199 skip_all,
200 fields(
201 db.query.text,
202 upstream_oauth_provider.id = %id,
203 ),
204 err,
205 )]
206 async fn lookup(
207 &mut self,
208 id: Ulid,
209 ) -> Result<Option<UpstreamOAuthAuthorizationSession>, Self::Error> {
210 let res = sqlx::query_as!(
211 SessionLookup,
212 r#"
213 SELECT
214 upstream_oauth_authorization_session_id,
215 upstream_oauth_provider_id,
216 upstream_oauth_link_id,
217 state,
218 code_challenge_verifier,
219 nonce,
220 id_token,
221 id_token_claims,
222 extra_callback_parameters,
223 userinfo,
224 created_at,
225 completed_at,
226 consumed_at,
227 unlinked_at
228 FROM upstream_oauth_authorization_sessions
229 WHERE upstream_oauth_authorization_session_id = $1
230 "#,
231 Uuid::from(id),
232 )
233 .traced()
234 .fetch_optional(&mut *self.conn)
235 .await?;
236
237 let Some(res) = res else { return Ok(None) };
238
239 Ok(Some(res.try_into()?))
240 }
241
242 #[tracing::instrument(
243 name = "db.upstream_oauth_authorization_session.add",
244 skip_all,
245 fields(
246 db.query.text,
247 %upstream_oauth_provider.id,
248 upstream_oauth_provider.issuer = upstream_oauth_provider.issuer,
249 %upstream_oauth_provider.client_id,
250 upstream_oauth_authorization_session.id,
251 ),
252 err,
253 )]
254 async fn add(
255 &mut self,
256 rng: &mut (dyn RngCore + Send),
257 clock: &dyn Clock,
258 upstream_oauth_provider: &UpstreamOAuthProvider,
259 state_str: String,
260 code_challenge_verifier: Option<String>,
261 nonce: Option<String>,
262 ) -> Result<UpstreamOAuthAuthorizationSession, Self::Error> {
263 let created_at = clock.now();
264 let id = Ulid::from_datetime_with_rng(created_at, rng);
265 tracing::Span::current().record(
266 "upstream_oauth_authorization_session.id",
267 tracing::field::display(id),
268 );
269
270 sqlx::query!(
271 r#"
272 INSERT INTO upstream_oauth_authorization_sessions (
273 upstream_oauth_authorization_session_id,
274 upstream_oauth_provider_id,
275 state,
276 code_challenge_verifier,
277 nonce,
278 created_at,
279 completed_at,
280 consumed_at,
281 id_token,
282 userinfo
283 ) VALUES ($1, $2, $3, $4, $5, $6, NULL, NULL, NULL, NULL)
284 "#,
285 Uuid::from(id),
286 Uuid::from(upstream_oauth_provider.id),
287 &state_str,
288 code_challenge_verifier.as_deref(),
289 nonce,
290 created_at,
291 )
292 .traced()
293 .execute(&mut *self.conn)
294 .await?;
295
296 Ok(UpstreamOAuthAuthorizationSession {
297 id,
298 state: UpstreamOAuthAuthorizationSessionState::default(),
299 provider_id: upstream_oauth_provider.id,
300 state_str,
301 code_challenge_verifier,
302 nonce,
303 created_at,
304 })
305 }
306
307 #[tracing::instrument(
308 name = "db.upstream_oauth_authorization_session.complete_with_link",
309 skip_all,
310 fields(
311 db.query.text,
312 %upstream_oauth_authorization_session.id,
313 %upstream_oauth_link.id,
314 ),
315 err,
316 )]
317 async fn complete_with_link(
318 &mut self,
319 clock: &dyn Clock,
320 upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession,
321 upstream_oauth_link: &UpstreamOAuthLink,
322 id_token: Option<String>,
323 id_token_claims: Option<serde_json::Value>,
324 extra_callback_parameters: Option<serde_json::Value>,
325 userinfo: Option<serde_json::Value>,
326 ) -> Result<UpstreamOAuthAuthorizationSession, Self::Error> {
327 let completed_at = clock.now();
328
329 sqlx::query!(
330 r#"
331 UPDATE upstream_oauth_authorization_sessions
332 SET upstream_oauth_link_id = $1
333 , completed_at = $2
334 , id_token = $3
335 , id_token_claims = $4
336 , extra_callback_parameters = $5
337 , userinfo = $6
338 WHERE upstream_oauth_authorization_session_id = $7
339 "#,
340 Uuid::from(upstream_oauth_link.id),
341 completed_at,
342 id_token,
343 id_token_claims,
344 extra_callback_parameters,
345 userinfo,
346 Uuid::from(upstream_oauth_authorization_session.id),
347 )
348 .traced()
349 .execute(&mut *self.conn)
350 .await?;
351
352 let upstream_oauth_authorization_session = upstream_oauth_authorization_session
353 .complete(
354 completed_at,
355 upstream_oauth_link,
356 id_token,
357 id_token_claims,
358 extra_callback_parameters,
359 userinfo,
360 )
361 .map_err(DatabaseError::to_invalid_operation)?;
362
363 Ok(upstream_oauth_authorization_session)
364 }
365
366 #[tracing::instrument(
368 name = "db.upstream_oauth_authorization_session.consume",
369 skip_all,
370 fields(
371 db.query.text,
372 %upstream_oauth_authorization_session.id,
373 ),
374 err,
375 )]
376 async fn consume(
377 &mut self,
378 clock: &dyn Clock,
379 upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession,
380 browser_session: &BrowserSession,
381 ) -> Result<UpstreamOAuthAuthorizationSession, Self::Error> {
382 let consumed_at = clock.now();
383 sqlx::query!(
384 r#"
385 UPDATE upstream_oauth_authorization_sessions
386 SET consumed_at = $1,
387 user_session_id = $2
388 WHERE upstream_oauth_authorization_session_id = $3
389 "#,
390 consumed_at,
391 Uuid::from(browser_session.id),
392 Uuid::from(upstream_oauth_authorization_session.id),
393 )
394 .traced()
395 .execute(&mut *self.conn)
396 .await?;
397
398 let upstream_oauth_authorization_session = upstream_oauth_authorization_session
399 .consume(consumed_at)
400 .map_err(DatabaseError::to_invalid_operation)?;
401
402 Ok(upstream_oauth_authorization_session)
403 }
404
405 #[tracing::instrument(
406 name = "db.upstream_oauth_authorization_session.list",
407 skip_all,
408 fields(
409 db.query.text,
410 ),
411 err,
412 )]
413 async fn list(
414 &mut self,
415 filter: UpstreamOAuthSessionFilter<'_>,
416 pagination: Pagination,
417 ) -> Result<Page<UpstreamOAuthAuthorizationSession>, Self::Error> {
418 let (sql, arguments) = Query::select()
419 .expr_as(
420 Expr::col((
421 UpstreamOAuthAuthorizationSessions::Table,
422 UpstreamOAuthAuthorizationSessions::UpstreamOAuthAuthorizationSessionId,
423 )),
424 SessionLookupIden::UpstreamOauthAuthorizationSessionId,
425 )
426 .expr_as(
427 Expr::col((
428 UpstreamOAuthAuthorizationSessions::Table,
429 UpstreamOAuthAuthorizationSessions::UpstreamOAuthProviderId,
430 )),
431 SessionLookupIden::UpstreamOauthProviderId,
432 )
433 .expr_as(
434 Expr::col((
435 UpstreamOAuthAuthorizationSessions::Table,
436 UpstreamOAuthAuthorizationSessions::UpstreamOAuthLinkId,
437 )),
438 SessionLookupIden::UpstreamOauthLinkId,
439 )
440 .expr_as(
441 Expr::col((
442 UpstreamOAuthAuthorizationSessions::Table,
443 UpstreamOAuthAuthorizationSessions::State,
444 )),
445 SessionLookupIden::State,
446 )
447 .expr_as(
448 Expr::col((
449 UpstreamOAuthAuthorizationSessions::Table,
450 UpstreamOAuthAuthorizationSessions::CodeChallengeVerifier,
451 )),
452 SessionLookupIden::CodeChallengeVerifier,
453 )
454 .expr_as(
455 Expr::col((
456 UpstreamOAuthAuthorizationSessions::Table,
457 UpstreamOAuthAuthorizationSessions::Nonce,
458 )),
459 SessionLookupIden::Nonce,
460 )
461 .expr_as(
462 Expr::col((
463 UpstreamOAuthAuthorizationSessions::Table,
464 UpstreamOAuthAuthorizationSessions::IdToken,
465 )),
466 SessionLookupIden::IdToken,
467 )
468 .expr_as(
469 Expr::col((
470 UpstreamOAuthAuthorizationSessions::Table,
471 UpstreamOAuthAuthorizationSessions::IdTokenClaims,
472 )),
473 SessionLookupIden::IdTokenClaims,
474 )
475 .expr_as(
476 Expr::col((
477 UpstreamOAuthAuthorizationSessions::Table,
478 UpstreamOAuthAuthorizationSessions::ExtraCallbackParameters,
479 )),
480 SessionLookupIden::ExtraCallbackParameters,
481 )
482 .expr_as(
483 Expr::col((
484 UpstreamOAuthAuthorizationSessions::Table,
485 UpstreamOAuthAuthorizationSessions::Userinfo,
486 )),
487 SessionLookupIden::Userinfo,
488 )
489 .expr_as(
490 Expr::col((
491 UpstreamOAuthAuthorizationSessions::Table,
492 UpstreamOAuthAuthorizationSessions::CreatedAt,
493 )),
494 SessionLookupIden::CreatedAt,
495 )
496 .expr_as(
497 Expr::col((
498 UpstreamOAuthAuthorizationSessions::Table,
499 UpstreamOAuthAuthorizationSessions::CompletedAt,
500 )),
501 SessionLookupIden::CompletedAt,
502 )
503 .expr_as(
504 Expr::col((
505 UpstreamOAuthAuthorizationSessions::Table,
506 UpstreamOAuthAuthorizationSessions::ConsumedAt,
507 )),
508 SessionLookupIden::ConsumedAt,
509 )
510 .expr_as(
511 Expr::col((
512 UpstreamOAuthAuthorizationSessions::Table,
513 UpstreamOAuthAuthorizationSessions::UnlinkedAt,
514 )),
515 SessionLookupIden::UnlinkedAt,
516 )
517 .from(UpstreamOAuthAuthorizationSessions::Table)
518 .apply_filter(filter)
519 .generate_pagination(
520 (
521 UpstreamOAuthAuthorizationSessions::Table,
522 UpstreamOAuthAuthorizationSessions::UpstreamOAuthAuthorizationSessionId,
523 ),
524 pagination,
525 )
526 .build_sqlx(PostgresQueryBuilder);
527
528 let edges: Vec<SessionLookup> = sqlx::query_as_with(&sql, arguments)
529 .traced()
530 .fetch_all(&mut *self.conn)
531 .await?;
532
533 let page = pagination
534 .process(edges)
535 .try_map(UpstreamOAuthAuthorizationSession::try_from)?;
536
537 Ok(page)
538 }
539
540 #[tracing::instrument(
541 name = "db.upstream_oauth_authorization_session.count",
542 skip_all,
543 fields(
544 db.query.text,
545 ),
546 err,
547 )]
548 async fn count(
549 &mut self,
550 filter: UpstreamOAuthSessionFilter<'_>,
551 ) -> Result<usize, Self::Error> {
552 let (sql, arguments) = Query::select()
553 .expr(
554 Expr::col((
555 UpstreamOAuthAuthorizationSessions::Table,
556 UpstreamOAuthAuthorizationSessions::UpstreamOAuthAuthorizationSessionId,
557 ))
558 .count(),
559 )
560 .from(UpstreamOAuthAuthorizationSessions::Table)
561 .apply_filter(filter)
562 .build_sqlx(PostgresQueryBuilder);
563
564 let count: i64 = sqlx::query_scalar_with(&sql, arguments)
565 .traced()
566 .fetch_one(&mut *self.conn)
567 .await?;
568
569 count
570 .try_into()
571 .map_err(DatabaseError::to_invalid_operation)
572 }
573
574 #[tracing::instrument(
575 name = "db.upstream_oauth_authorization_session.cleanup",
576 skip_all,
577 fields(
578 db.query.text,
579 since = since.map(tracing::field::display),
580 until = %until,
581 limit = limit,
582 ),
583 err,
584 )]
585 async fn cleanup_orphaned(
586 &mut self,
587 since: Option<Ulid>,
588 until: Ulid,
589 limit: usize,
590 ) -> Result<(usize, Option<Ulid>), Self::Error> {
591 let res = sqlx::query_scalar!(
595 r#"
596 WITH to_delete AS (
597 SELECT upstream_oauth_authorization_session_id
598 FROM upstream_oauth_authorization_sessions
599 WHERE ($1::uuid IS NULL OR upstream_oauth_authorization_session_id > $1)
600 AND upstream_oauth_authorization_session_id <= $2
601 AND user_session_id IS NULL
602 ORDER BY upstream_oauth_authorization_session_id
603 LIMIT $3
604 )
605 DELETE FROM upstream_oauth_authorization_sessions
606 USING to_delete
607 WHERE upstream_oauth_authorization_sessions.upstream_oauth_authorization_session_id = to_delete.upstream_oauth_authorization_session_id
608 RETURNING upstream_oauth_authorization_sessions.upstream_oauth_authorization_session_id
609 "#,
610 since.map(Uuid::from),
611 Uuid::from(until),
612 i64::try_from(limit).unwrap_or(i64::MAX)
613 )
614 .traced()
615 .fetch_all(&mut *self.conn)
616 .await?;
617
618 let count = res.len();
619 let max_id = res.into_iter().max();
620
621 Ok((count, max_id.map(Ulid::from)))
622 }
623}