1use std::net::IpAddr;
8
9use async_trait::async_trait;
10use chrono::{DateTime, Utc};
11use mas_data_model::{
12 Clock, UlidExt as _, User,
13 personal::{
14 PersonalAccessToken,
15 session::{PersonalSession, PersonalSessionOwner, SessionState},
16 },
17};
18use mas_storage::{
19 Page, Pagination,
20 pagination::Node,
21 personal::{PersonalSessionFilter, PersonalSessionRepository, PersonalSessionState},
22};
23use oauth2_types::scope::Scope;
24use opentelemetry_semantic_conventions::trace::DB_QUERY_TEXT;
25use rand::RngCore;
26use sea_query::{
27 Cond, Condition, Expr, ExprTrait, PgFunc, PostgresQueryBuilder, Query, SimpleExpr, enum_def,
28 extension::postgres::PgExpr as _,
29};
30use sea_query_sqlx::SqlxBinder as _;
31use sqlx::PgConnection;
32use tracing::{Instrument as _, info_span};
33use ulid::Ulid;
34use uuid::Uuid;
35
36use crate::{
37 DatabaseError,
38 errors::DatabaseInconsistencyError,
39 filter::{Filter, StatementExt as _},
40 iden::{PersonalAccessTokens, PersonalSessions},
41 pagination::QueryBuilderExt as _,
42 tracing::ExecuteExt as _,
43};
44
/// An implementation of [`PersonalSessionRepository`] backed by a PostgreSQL
/// connection.
pub struct PgPersonalSessionRepository<'c> {
    // Mutable borrow of the connection that every query of this repository
    // runs on.
    conn: &'c mut PgConnection,
}
50
51impl<'c> PgPersonalSessionRepository<'c> {
52 pub fn new(conn: &'c mut PgConnection) -> Self {
55 Self { conn }
56 }
57}
58
// Raw row of the `personal_sessions` table, as selected by `lookup` through
// `sqlx::query_as!`. It is turned into a `PersonalSession` by the `TryFrom`
// impl below, which validates the owner columns and parses the scope.
// `#[enum_def]` additionally generates a `PersonalSessionLookupIden` column
// enum for use with sea-query.
#[derive(sqlx::FromRow)]
#[enum_def]
struct PersonalSessionLookup {
    personal_session_id: Uuid,
    // Exactly one of the two owner columns is expected to be set; the
    // conversion into `PersonalSession` rejects rows where that does not hold.
    owner_user_id: Option<Uuid>,
    owner_oauth2_client_id: Option<Uuid>,
    actor_user_id: Uuid,
    human_name: String,
    // Scope tokens as stored; each one is parsed when building the `Scope`.
    scope_list: Vec<String>,
    created_at: DateTime<Utc>,
    // `None` means the session is still valid.
    revoked_at: Option<DateTime<Utc>>,
    last_active_at: Option<DateTime<Utc>>,
    last_active_ip: Option<IpAddr>,
}
73
74impl Node<Ulid> for PersonalSessionLookup {
75 fn cursor(&self) -> Ulid {
76 self.personal_session_id.into()
77 }
78}
79
80impl TryFrom<PersonalSessionLookup> for PersonalSession {
81 type Error = DatabaseInconsistencyError;
82
83 fn try_from(value: PersonalSessionLookup) -> Result<Self, Self::Error> {
84 let id = Ulid::from(value.personal_session_id);
85 let scope: Result<Scope, _> = value.scope_list.iter().map(|s| s.parse()).collect();
86 let scope = scope.map_err(|e| {
87 DatabaseInconsistencyError::on("personal_sessions")
88 .column("scope")
89 .row(id)
90 .source(e)
91 })?;
92
93 let state = match value.revoked_at {
94 None => SessionState::Valid,
95 Some(revoked_at) => SessionState::Revoked { revoked_at },
96 };
97
98 let owner = match (value.owner_user_id, value.owner_oauth2_client_id) {
99 (Some(owner_user_id), None) => PersonalSessionOwner::User(Ulid::from(owner_user_id)),
100 (None, Some(owner_oauth2_client_id)) => {
101 PersonalSessionOwner::OAuth2Client(Ulid::from(owner_oauth2_client_id))
102 }
103 _ => {
104 return Err(DatabaseInconsistencyError::on("personal_sessions")
106 .column("owner_user_id, owner_oauth2_client_id")
107 .row(id));
108 }
109 };
110
111 Ok(PersonalSession {
112 id,
113 state,
114 owner,
115 actor_user_id: Ulid::from(value.actor_user_id),
116 human_name: value.human_name,
117 scope,
118 created_at: value.created_at,
119 last_active_at: value.last_active_at,
120 last_active_ip: value.last_active_ip,
121 })
122 }
123}
124
// Row returned by `list`: the `personal_sessions` columns plus the columns of
// the session's non-revoked access token, brought in by a LEFT JOIN on
// `personal_access_tokens`. `#[enum_def]` generates the
// `PersonalSessionAndAccessTokenLookupIden` enum whose variants `list` uses as
// column aliases, so each selected column maps onto a field of this struct.
#[derive(sqlx::FromRow)]
#[enum_def]
struct PersonalSessionAndAccessTokenLookup {
    // Session columns, same meaning as in `PersonalSessionLookup`.
    personal_session_id: Uuid,
    owner_user_id: Option<Uuid>,
    owner_oauth2_client_id: Option<Uuid>,
    actor_user_id: Uuid,
    human_name: String,
    scope_list: Vec<String>,
    created_at: DateTime<Utc>,
    revoked_at: Option<DateTime<Utc>>,
    last_active_at: Option<DateTime<Utc>>,
    last_active_ip: Option<IpAddr>,

    // Token columns: all NULL when the LEFT JOIN matched no token.
    personal_access_token_id: Option<Uuid>,
    // Aliased from `personal_access_tokens.created_at` / `.expires_at`; the
    // `token_` prefix keeps them apart from the session's own `created_at`.
    token_created_at: Option<DateTime<Utc>>,
    token_expires_at: Option<DateTime<Utc>>,
}
144
145impl Node<Ulid> for PersonalSessionAndAccessTokenLookup {
146 fn cursor(&self) -> Ulid {
147 self.personal_session_id.into()
148 }
149}
150
151impl TryFrom<PersonalSessionAndAccessTokenLookup>
152 for (PersonalSession, Option<PersonalAccessToken>)
153{
154 type Error = DatabaseInconsistencyError;
155
156 fn try_from(value: PersonalSessionAndAccessTokenLookup) -> Result<Self, Self::Error> {
157 let session = PersonalSession::try_from(PersonalSessionLookup {
158 personal_session_id: value.personal_session_id,
159 owner_user_id: value.owner_user_id,
160 owner_oauth2_client_id: value.owner_oauth2_client_id,
161 actor_user_id: value.actor_user_id,
162 human_name: value.human_name,
163 scope_list: value.scope_list,
164 created_at: value.created_at,
165 revoked_at: value.revoked_at,
166 last_active_at: value.last_active_at,
167 last_active_ip: value.last_active_ip,
168 })?;
169
170 let token_opt = if let Some(id) = value.personal_access_token_id {
171 let id = Ulid::from(id);
172 Some(PersonalAccessToken {
173 id,
174 session_id: session.id,
175 created_at: value.token_created_at.ok_or(
177 DatabaseInconsistencyError::on("personal_sessions")
178 .column("created_at")
179 .row(id),
180 )?,
181 expires_at: value.token_expires_at,
182 revoked_at: None,
183 })
184 } else {
185 None
186 };
187
188 Ok((session, token_opt))
189 }
190}
191
#[async_trait]
impl PersonalSessionRepository for PgPersonalSessionRepository<'_> {
    type Error = DatabaseError;

    // Fetch a single personal session by ID.
    //
    // Returns `Ok(None)` when no row matches, and an error if the row cannot
    // be converted into a `PersonalSession` (invalid scope token, or not
    // exactly one owner column set).
    #[tracing::instrument(
        name = "db.personal_session.lookup",
        skip_all,
        fields(
            db.query.text,
            session.id = %id,
        ),
        err,
    )]
    async fn lookup(&mut self, id: Ulid) -> Result<Option<PersonalSession>, Self::Error> {
        let res = sqlx::query_as!(
            PersonalSessionLookup,
            r#"
                SELECT personal_session_id
                     , owner_user_id
                     , owner_oauth2_client_id
                     , actor_user_id
                     , scope_list
                     , created_at
                     , revoked_at
                     , human_name
                     , last_active_at
                     , last_active_ip as "last_active_ip: IpAddr"
                FROM personal_sessions

                WHERE personal_session_id = $1
            "#,
            Uuid::from(id),
        )
        .traced()
        .fetch_optional(&mut *self.conn)
        .await?;

        let Some(session) = res else { return Ok(None) };

        Ok(Some(session.try_into()?))
    }

    // Insert a new, valid personal session and return it.
    //
    // The ID is a ULID built from the clock's current time and `rng`. The
    // `owner` variant decides which of `owner_user_id` /
    // `owner_oauth2_client_id` is set; the other stays NULL. The returned
    // session has no recorded activity yet.
    #[tracing::instrument(
        name = "db.personal_session.add",
        skip_all,
        fields(
            db.query.text,
            session.id,
            session.scope = %scope,
        ),
        err,
    )]
    async fn add(
        &mut self,
        rng: &mut (dyn RngCore + Send),
        clock: &dyn Clock,
        owner: PersonalSessionOwner,
        actor_user: &User,
        human_name: String,
        scope: Scope,
    ) -> Result<PersonalSession, Self::Error> {
        let created_at = clock.now();
        let id = Ulid::from_datetime_with_rng(created_at, rng);
        // `session.id` is declared empty in the span and filled in here, once
        // the ID is known.
        tracing::Span::current().record("session.id", tracing::field::display(id));

        // The scope is stored as a text array, one element per token.
        let scope_list: Vec<String> = scope.iter().map(|s| s.as_str().to_owned()).collect();

        let (owner_user_id, owner_oauth2_client_id) = match owner {
            PersonalSessionOwner::User(ulid) => (Some(Uuid::from(ulid)), None),
            PersonalSessionOwner::OAuth2Client(ulid) => (None, Some(Uuid::from(ulid))),
        };

        sqlx::query!(
            r#"
                INSERT INTO personal_sessions
                    ( personal_session_id
                    , owner_user_id
                    , owner_oauth2_client_id
                    , actor_user_id
                    , human_name
                    , scope_list
                    , created_at
                    )
                VALUES ($1, $2, $3, $4, $5, $6, $7)
            "#,
            Uuid::from(id),
            owner_user_id,
            owner_oauth2_client_id,
            Uuid::from(actor_user.id),
            &human_name,
            &scope_list,
            created_at,
        )
        .traced()
        .execute(&mut *self.conn)
        .await?;

        Ok(PersonalSession {
            id,
            state: SessionState::Valid,
            owner,
            actor_user_id: actor_user.id,
            human_name,
            scope,
            created_at,
            last_active_at: None,
            last_active_ip: None,
        })
    }

    // Revoke a session and all of its not-yet-revoked access tokens, using
    // the same `revoked_at` timestamp for both.
    //
    // Fails if the session UPDATE does not affect exactly one row, or if
    // `session.finish` rejects the transition (mapped to an invalid-operation
    // error).
    #[tracing::instrument(
        name = "db.personal_session.revoke",
        skip_all,
        fields(
            db.query.text,
            %session.id,
            %session.scope,
        ),
        err,
    )]
    async fn revoke(
        &mut self,
        clock: &dyn Clock,
        session: PersonalSession,
    ) -> Result<PersonalSession, Self::Error> {
        let revoked_at = clock.now();

        // Revoke the session's active tokens first, in a dedicated child span
        // (`db.personal_session.revoke.tokens`) that the query is recorded on.
        {
            let span = info_span!(
                "db.personal_session.revoke.tokens",
                { DB_QUERY_TEXT } = tracing::field::Empty,
            );

            sqlx::query!(
                r#"
                    UPDATE personal_access_tokens
                    SET revoked_at = $2
                    WHERE personal_session_id = $1 AND revoked_at IS NULL
                "#,
                Uuid::from(session.id),
                revoked_at,
            )
            .record(&span)
            .execute(&mut *self.conn)
            .instrument(span)
            .await?;
        }

        let res = sqlx::query!(
            r#"
                UPDATE personal_sessions
                SET revoked_at = $2
                WHERE personal_session_id = $1
            "#,
            Uuid::from(session.id),
            revoked_at,
        )
        .traced()
        .execute(&mut *self.conn)
        .await?;

        DatabaseError::ensure_affected_rows(&res, 1)?;

        session
            .finish(revoked_at)
            .map_err(DatabaseError::to_invalid_operation)
    }

    // Revoke every session matching `filter` and return the number of updated
    // rows (saturating at `usize::MAX`).
    //
    // The filter can reference `personal_access_tokens` columns (the expiry
    // criteria), so it is applied to a SELECT that LEFT JOINs the non-revoked
    // tokens, and the UPDATE targets the session IDs that subquery returns.
    //
    // NOTE(review): unlike `revoke`, this does not set `revoked_at` on the
    // sessions' access tokens, and it overwrites `revoked_at` on sessions that
    // are already revoked unless the filter restricts the state. Confirm both
    // are intended.
    #[tracing::instrument(
        name = "db.personal_session.revoke_bulk",
        skip_all,
        fields(
            db.query.text,
        ),
        err,
    )]
    async fn revoke_bulk(
        &mut self,
        clock: &dyn Clock,
        filter: PersonalSessionFilter<'_>,
    ) -> Result<usize, Self::Error> {
        let revoked_at = clock.now();

        let (sql, arguments) = Query::update()
            .table(PersonalSessions::Table)
            .value(PersonalSessions::RevokedAt, revoked_at)
            .and_where(
                Expr::col((PersonalSessions::Table, PersonalSessions::PersonalSessionId))
                    .in_subquery(
                        // The join lives in a subquery, presumably because the
                        // UPDATE statement cannot carry it directly.
                        Query::select()
                            .expr(Expr::col((
                                PersonalSessions::Table,
                                PersonalSessions::PersonalSessionId,
                            )))
                            .from(PersonalSessions::Table)
                            .left_join(
                                PersonalAccessTokens::Table,
                                Cond::all()
                                    .add(
                                        // Join each session to its tokens...
                                        Expr::col((
                                            PersonalSessions::Table,
                                            PersonalSessions::PersonalSessionId,
                                        ))
                                        .eq(Expr::col((
                                            PersonalAccessTokens::Table,
                                            PersonalAccessTokens::PersonalSessionId,
                                        ))),
                                    )
                                    .add(
                                        // ...but only the ones not revoked yet.
                                        Expr::col((
                                            PersonalAccessTokens::Table,
                                            PersonalAccessTokens::RevokedAt,
                                        ))
                                        .is_null(),
                                    ),
                            )
                            .apply_filter(filter)
                            .take(),
                    ),
            )
            .build_sqlx(PostgresQueryBuilder);

        let res = sqlx::query_with(&sql, arguments)
            .traced()
            .execute(&mut *self.conn)
            .await?;

        Ok(res.rows_affected().try_into().unwrap_or(usize::MAX))
    }

    // List sessions matching `filter`, paginated on the session ID, each
    // paired with its non-revoked access token when the LEFT JOIN finds one.
    //
    // NOTE(review): if a session could have several non-revoked tokens at
    // once, the join would return one row per token for that session;
    // presumably the schema allows at most one — confirm.
    #[tracing::instrument(
        name = "db.personal_session.list",
        skip_all,
        fields(
            db.query.text,
        ),
        err,
    )]
    async fn list(
        &mut self,
        filter: PersonalSessionFilter<'_>,
        pagination: Pagination,
    ) -> Result<Page<(PersonalSession, Option<PersonalAccessToken>)>, Self::Error> {
        // Every selected column is aliased to a field name of
        // `PersonalSessionAndAccessTokenLookup` so that `FromRow` can map it.
        let (sql, arguments) = Query::select()
            .expr_as(
                Expr::col((PersonalSessions::Table, PersonalSessions::PersonalSessionId)),
                PersonalSessionAndAccessTokenLookupIden::PersonalSessionId,
            )
            .expr_as(
                Expr::col((PersonalSessions::Table, PersonalSessions::OwnerUserId)),
                PersonalSessionAndAccessTokenLookupIden::OwnerUserId,
            )
            .expr_as(
                Expr::col((
                    PersonalSessions::Table,
                    PersonalSessions::OwnerOAuth2ClientId,
                )),
                PersonalSessionAndAccessTokenLookupIden::OwnerOauth2ClientId,
            )
            .expr_as(
                Expr::col((PersonalSessions::Table, PersonalSessions::ActorUserId)),
                PersonalSessionAndAccessTokenLookupIden::ActorUserId,
            )
            .expr_as(
                Expr::col((PersonalSessions::Table, PersonalSessions::HumanName)),
                PersonalSessionAndAccessTokenLookupIden::HumanName,
            )
            .expr_as(
                Expr::col((PersonalSessions::Table, PersonalSessions::ScopeList)),
                PersonalSessionAndAccessTokenLookupIden::ScopeList,
            )
            .expr_as(
                Expr::col((PersonalSessions::Table, PersonalSessions::CreatedAt)),
                PersonalSessionAndAccessTokenLookupIden::CreatedAt,
            )
            .expr_as(
                Expr::col((PersonalSessions::Table, PersonalSessions::RevokedAt)),
                PersonalSessionAndAccessTokenLookupIden::RevokedAt,
            )
            .expr_as(
                Expr::col((PersonalSessions::Table, PersonalSessions::LastActiveAt)),
                PersonalSessionAndAccessTokenLookupIden::LastActiveAt,
            )
            .expr_as(
                Expr::col((PersonalSessions::Table, PersonalSessions::LastActiveIp)),
                PersonalSessionAndAccessTokenLookupIden::LastActiveIp,
            )
            .expr_as(
                Expr::col((
                    PersonalAccessTokens::Table,
                    PersonalAccessTokens::PersonalAccessTokenId,
                )),
                PersonalSessionAndAccessTokenLookupIden::PersonalAccessTokenId,
            )
            .expr_as(
                Expr::col((PersonalAccessTokens::Table, PersonalAccessTokens::CreatedAt)),
                PersonalSessionAndAccessTokenLookupIden::TokenCreatedAt,
            )
            .expr_as(
                Expr::col((PersonalAccessTokens::Table, PersonalAccessTokens::ExpiresAt)),
                PersonalSessionAndAccessTokenLookupIden::TokenExpiresAt,
            )
            .from(PersonalSessions::Table)
            .left_join(
                PersonalAccessTokens::Table,
                Cond::all()
                    .add(
                        // Join each session to its tokens...
                        Expr::col((PersonalSessions::Table, PersonalSessions::PersonalSessionId))
                            .eq(Expr::col((
                                PersonalAccessTokens::Table,
                                PersonalAccessTokens::PersonalSessionId,
                            ))),
                    )
                    .add(
                        // ...but only the ones not revoked yet.
                        Expr::col((PersonalAccessTokens::Table, PersonalAccessTokens::RevokedAt))
                            .is_null(),
                    ),
            )
            .apply_filter(filter)
            .generate_pagination(
                (PersonalSessions::Table, PersonalSessions::PersonalSessionId),
                pagination,
            )
            .build_sqlx(PostgresQueryBuilder);

        let edges: Vec<PersonalSessionAndAccessTokenLookup> = sqlx::query_as_with(&sql, arguments)
            .traced()
            .fetch_all(&mut *self.conn)
            .await?;

        // Each row is converted into a `(PersonalSession, Option<PersonalAccessToken>)`;
        // the first conversion error aborts the whole listing.
        let page = pagination.process(edges).try_map(TryFrom::try_from)?;

        Ok(page)
    }

    // Count the rows matching `filter`, using the same LEFT JOIN on
    // non-revoked tokens as `list` so token-based criteria behave the same.
    //
    // NOTE(review): this counts joined rows, not distinct sessions; it only
    // equals the session count if a session has at most one non-revoked
    // token — confirm against the schema.
    #[tracing::instrument(
        name = "db.personal_session.count",
        skip_all,
        fields(
            db.query.text,
        ),
        err,
    )]
    async fn count(&mut self, filter: PersonalSessionFilter<'_>) -> Result<usize, Self::Error> {
        let (sql, arguments) = Query::select()
            .expr(Expr::col((PersonalSessions::Table, PersonalSessions::PersonalSessionId)).count())
            .from(PersonalSessions::Table)
            .left_join(
                PersonalAccessTokens::Table,
                Cond::all()
                    .add(
                        // Join each session to its tokens...
                        Expr::col((PersonalSessions::Table, PersonalSessions::PersonalSessionId))
                            .eq(Expr::col((
                                PersonalAccessTokens::Table,
                                PersonalAccessTokens::PersonalSessionId,
                            ))),
                    )
                    .add(
                        // ...but only the ones not revoked yet.
                        Expr::col((PersonalAccessTokens::Table, PersonalAccessTokens::RevokedAt))
                            .is_null(),
                    ),
            )
            .apply_filter(filter)
            .build_sqlx(PostgresQueryBuilder);

        let count: i64 = sqlx::query_scalar_with(&sql, arguments)
            .traced()
            .fetch_one(&mut *self.conn)
            .await?;

        // A negative count cannot be represented as `usize`; it surfaces as an
        // invalid-operation error rather than a panic.
        count
            .try_into()
            .map_err(DatabaseError::to_invalid_operation)
    }

    // Record activity for many sessions with a single UPDATE.
    //
    // Each entry is (session ID, activity time, optional IP).
    // `last_active_at` takes the later of the stored and new values
    // (GREATEST), and a NULL IP keeps the stored one (COALESCE). The number
    // of updated rows must equal the number of entries, so an unknown
    // session ID makes the call fail.
    //
    // NOTE(review): a session ID appearing twice in `activities` would also
    // make the updated-row count fall short, since Postgres updates each
    // target row once; presumably callers deduplicate — confirm.
    #[tracing::instrument(
        name = "db.personal_session.record_batch_activity",
        skip_all,
        fields(
            db.query.text,
        ),
        err,
    )]
    async fn record_batch_activity(
        &mut self,
        mut activities: Vec<(Ulid, DateTime<Utc>, Option<IpAddr>)>,
    ) -> Result<(), Self::Error> {
        // Sort by session ID (the first tuple element) so rows are always
        // updated in the same order — presumably to keep row-lock ordering
        // consistent between concurrent batches.
        activities.sort_unstable();
        let mut ids = Vec::with_capacity(activities.len());
        let mut last_activities = Vec::with_capacity(activities.len());
        let mut ips = Vec::with_capacity(activities.len());

        // Split the tuples into parallel arrays for `UNNEST`.
        for (id, last_activity, ip) in activities {
            ids.push(Uuid::from(id));
            last_activities.push(last_activity);
            ips.push(ip);
        }

        let res = sqlx::query!(
            r#"
                UPDATE personal_sessions
                SET last_active_at = GREATEST(t.last_active_at, personal_sessions.last_active_at)
                  , last_active_ip = COALESCE(t.last_active_ip, personal_sessions.last_active_ip)
                FROM (
                    SELECT *
                    FROM UNNEST($1::uuid[], $2::timestamptz[], $3::inet[])
                        AS t(personal_session_id, last_active_at, last_active_ip)
                ) AS t
                WHERE personal_sessions.personal_session_id = t.personal_session_id
            "#,
            &ids,
            &last_activities,
            &ips as &[Option<IpAddr>],
        )
        .traced()
        .execute(&mut *self.conn)
        .await?;

        DatabaseError::ensure_affected_rows(&res, ids.len().try_into().unwrap_or(u64::MAX))?;

        Ok(())
    }
}
627
628impl Filter for PersonalSessionFilter<'_> {
629 fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
630 sea_query::Condition::all()
631 .add_option(self.owner_user().map(|user| {
632 Expr::col((PersonalSessions::Table, PersonalSessions::OwnerUserId))
633 .eq(Uuid::from(user.id))
634 }))
635 .add_option(self.owner_oauth2_client().map(|client| {
636 Expr::col((
637 PersonalSessions::Table,
638 PersonalSessions::OwnerOAuth2ClientId,
639 ))
640 .eq(Uuid::from(client.id))
641 }))
642 .add_option(self.actor_user().map(|user| {
643 Expr::col((PersonalSessions::Table, PersonalSessions::ActorUserId))
644 .eq(Uuid::from(user.id))
645 }))
646 .add_option(self.device().map(|device| -> SimpleExpr {
647 if let Ok([stable_scope_token, unstable_scope_token]) = device.to_scope_token() {
648 Condition::any()
649 .add(
650 Expr::val(stable_scope_token.to_string()).eq(PgFunc::any(Expr::col((
651 PersonalSessions::Table,
652 PersonalSessions::ScopeList,
653 )))),
654 )
655 .add(Expr::val(unstable_scope_token.to_string()).eq(PgFunc::any(
656 Expr::col((PersonalSessions::Table, PersonalSessions::ScopeList)),
657 )))
658 .into()
659 } else {
660 Expr::val(false)
662 }
663 }))
664 .add_option(self.state().map(|state| match state {
665 PersonalSessionState::Active => {
666 Expr::col((PersonalSessions::Table, PersonalSessions::RevokedAt)).is_null()
667 }
668 PersonalSessionState::Revoked => {
669 Expr::col((PersonalSessions::Table, PersonalSessions::RevokedAt)).is_not_null()
670 }
671 }))
672 .add_option(self.scope().map(|scope| {
673 let scope: Vec<String> = scope.iter().map(|s| s.as_str().to_owned()).collect();
674 Expr::col((PersonalSessions::Table, PersonalSessions::ScopeList)).contains(scope)
675 }))
676 .add_option(self.last_active_before().map(|last_active_before| {
677 Expr::col((PersonalSessions::Table, PersonalSessions::LastActiveAt))
678 .lt(last_active_before)
679 }))
680 .add_option(self.last_active_after().map(|last_active_after| {
681 Expr::col((PersonalSessions::Table, PersonalSessions::LastActiveAt))
682 .gt(last_active_after)
683 }))
684 .add_option(self.expires_before().map(|expires_before| {
685 Expr::col((PersonalAccessTokens::Table, PersonalAccessTokens::ExpiresAt))
686 .lt(expires_before)
687 }))
688 .add_option(self.expires_after().map(|expires_after| {
689 Expr::col((PersonalAccessTokens::Table, PersonalAccessTokens::ExpiresAt))
690 .gt(expires_after)
691 }))
692 .add_option(self.expires().map(|expires| {
693 let column =
694 Expr::col((PersonalAccessTokens::Table, PersonalAccessTokens::ExpiresAt));
695
696 if expires {
697 column.is_not_null()
698 } else {
699 column.is_null()
700 }
701 }))
702 }
703}