1use std::net::IpAddr;
8
9use async_trait::async_trait;
10use chrono::{DateTime, Utc};
11use mas_data_model::{BrowserSession, Client, Session, SessionState, User};
12use mas_storage::{
13 Clock, Page, Pagination,
14 oauth2::{OAuth2SessionFilter, OAuth2SessionRepository},
15};
16use oauth2_types::scope::{Scope, ScopeToken};
17use rand::RngCore;
18use sea_query::{Expr, PgFunc, PostgresQueryBuilder, Query, enum_def, extension::postgres::PgExpr};
19use sea_query_binder::SqlxBinder;
20use sqlx::PgConnection;
21use ulid::Ulid;
22use uuid::Uuid;
23
24use crate::{
25 DatabaseError, DatabaseInconsistencyError,
26 filter::{Filter, StatementExt},
27 iden::{OAuth2Clients, OAuth2Sessions, UserSessions},
28 pagination::QueryBuilderExt,
29 tracing::ExecuteExt,
30};
31
32pub struct PgOAuth2SessionRepository<'c> {
34 conn: &'c mut PgConnection,
35}
36
37impl<'c> PgOAuth2SessionRepository<'c> {
38 pub fn new(conn: &'c mut PgConnection) -> Self {
41 Self { conn }
42 }
43}
44
/// Raw row shape returned by the OAuth 2.0 session `lookup` and `list` queries.
///
/// `sqlx::FromRow` maps result columns onto these fields by name, and
/// `#[enum_def]` generates the `OAuthSessionLookupIden` enum whose variants are
/// used as column aliases in the dynamically-built `list` query. Field names
/// therefore have to stay in sync with both the SQL and those aliases.
///
/// Converted into a [`Session`] through its `TryFrom` implementation.
#[derive(sqlx::FromRow)]
#[enum_def]
struct OAuthSessionLookup {
    oauth2_session_id: Uuid,
    // `NULL` when the session was created without an associated user.
    user_id: Option<Uuid>,
    // `NULL` when the session is not linked to a browser (user) session.
    user_session_id: Option<Uuid>,
    oauth2_client_id: Uuid,
    // Raw scope tokens; parsed into a `Scope` during conversion.
    scope_list: Vec<String>,
    created_at: DateTime<Utc>,
    // `NULL` while the session is still valid.
    finished_at: Option<DateTime<Utc>>,
    user_agent: Option<String>,
    last_active_at: Option<DateTime<Utc>>,
    // Stored as `inet`; the `lookup` query uses a type override to decode it.
    last_active_ip: Option<IpAddr>,
    human_name: Option<String>,
}
60
61impl TryFrom<OAuthSessionLookup> for Session {
62 type Error = DatabaseInconsistencyError;
63
64 fn try_from(value: OAuthSessionLookup) -> Result<Self, Self::Error> {
65 let id = Ulid::from(value.oauth2_session_id);
66 let scope: Result<Scope, _> = value
67 .scope_list
68 .iter()
69 .map(|s| s.parse::<ScopeToken>())
70 .collect();
71 let scope = scope.map_err(|e| {
72 DatabaseInconsistencyError::on("oauth2_sessions")
73 .column("scope")
74 .row(id)
75 .source(e)
76 })?;
77
78 let state = match value.finished_at {
79 None => SessionState::Valid,
80 Some(finished_at) => SessionState::Finished { finished_at },
81 };
82
83 Ok(Session {
84 id,
85 state,
86 created_at: value.created_at,
87 client_id: value.oauth2_client_id.into(),
88 user_id: value.user_id.map(Ulid::from),
89 user_session_id: value.user_session_id.map(Ulid::from),
90 scope,
91 user_agent: value.user_agent,
92 last_active_at: value.last_active_at,
93 last_active_ip: value.last_active_ip,
94 human_name: value.human_name,
95 })
96 }
97}
98
impl Filter for OAuth2SessionFilter<'_> {
    /// Translate every criterion set on this filter into a SQL condition on
    /// `oauth2_sessions`. All conditions are AND-ed together (`Condition::all`);
    /// criteria that are not set (`None`) contribute nothing.
    ///
    /// `_has_joins` is ignored: every column reference below is already fully
    /// qualified with its table.
    fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
        sea_query::Condition::all()
            // Sessions owned by a specific user.
            .add_option(self.user().map(|user| {
                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserId)).eq(Uuid::from(user.id))
            }))
            // Sessions issued to a specific client.
            .add_option(self.client().map(|client| {
                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2ClientId))
                    .eq(Uuid::from(client.id))
            }))
            // Static vs. non-static clients: compare the session's client ID
            // against a subquery listing the IDs of all static clients, using
            // `= ANY(...)` for static and `<> ALL(...)` for non-static.
            .add_option(self.client_kind().map(|client_kind| {
                let static_clients = Query::select()
                    .expr(Expr::col((
                        OAuth2Clients::Table,
                        OAuth2Clients::OAuth2ClientId,
                    )))
                    .and_where(Expr::col((OAuth2Clients::Table, OAuth2Clients::IsStatic)).into())
                    .from(OAuth2Clients::Table)
                    .take();
                if client_kind.is_static() {
                    Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2ClientId))
                        .eq(Expr::any(static_clients))
                } else {
                    Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2ClientId))
                        .ne(Expr::all(static_clients))
                }
            }))
            // Sessions whose scope list contains the device's scope token
            // (`token = ANY(scope_list)`). If the device cannot be expressed
            // as a scope token, the condition is a constant `false`, so
            // nothing matches.
            .add_option(self.device().map(|device| {
                if let Ok(scope_token) = device.to_scope_token() {
                    Expr::val(scope_token.to_string()).eq(PgFunc::any(Expr::col((
                        OAuth2Sessions::Table,
                        OAuth2Sessions::ScopeList,
                    ))))
                } else {
                    Expr::val(false).into()
                }
            }))
            // Sessions attached to one specific browser session.
            .add_option(self.browser_session().map(|browser_session| {
                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserSessionId))
                    .eq(Uuid::from(browser_session.id))
            }))
            // Sessions whose browser session matches a nested user-session
            // filter, expressed as `user_session_id IN (SELECT ...)`.
            .add_option(self.browser_session_filter().map(|browser_session_filter| {
                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserSessionId)).in_subquery(
                    Query::select()
                        .expr(Expr::col((
                            UserSessions::Table,
                            UserSessions::UserSessionId,
                        )))
                        .apply_filter(browser_session_filter)
                        .from(UserSessions::Table)
                        .take(),
                )
            }))
            // Active sessions have no `finished_at`; finished ones have one.
            .add_option(self.state().map(|state| {
                if state.is_active() {
                    Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_null()
                } else {
                    Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_not_null()
                }
            }))
            // Sessions whose scope list contains every requested token
            // (Postgres array containment).
            .add_option(self.scope().map(|scope| {
                let scope: Vec<String> = scope.iter().map(|s| s.as_str().to_owned()).collect();
                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::ScopeList)).contains(scope)
            }))
            // `true` keeps sessions that have a user, `false` keeps those without.
            .add_option(self.any_user().map(|any_user| {
                if any_user {
                    Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserId)).is_not_null()
                } else {
                    Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserId)).is_null()
                }
            }))
            // Strict lower bound on `last_active_at`. Rows where it is NULL
            // never satisfy the comparison.
            .add_option(self.last_active_after().map(|last_active_after| {
                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::LastActiveAt))
                    .gt(last_active_after)
            }))
            // Strict upper bound on `last_active_at`. Rows where it is NULL
            // never satisfy the comparison.
            .add_option(self.last_active_before().map(|last_active_before| {
                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::LastActiveAt))
                    .lt(last_active_before)
            }))
    }
}
184
185#[async_trait]
186impl OAuth2SessionRepository for PgOAuth2SessionRepository<'_> {
187 type Error = DatabaseError;
188
189 #[tracing::instrument(
190 name = "db.oauth2_session.lookup",
191 skip_all,
192 fields(
193 db.query.text,
194 session.id = %id,
195 ),
196 err,
197 )]
198 async fn lookup(&mut self, id: Ulid) -> Result<Option<Session>, Self::Error> {
199 let res = sqlx::query_as!(
200 OAuthSessionLookup,
201 r#"
202 SELECT oauth2_session_id
203 , user_id
204 , user_session_id
205 , oauth2_client_id
206 , scope_list
207 , created_at
208 , finished_at
209 , user_agent
210 , last_active_at
211 , last_active_ip as "last_active_ip: IpAddr"
212 , human_name
213 FROM oauth2_sessions
214
215 WHERE oauth2_session_id = $1
216 "#,
217 Uuid::from(id),
218 )
219 .traced()
220 .fetch_optional(&mut *self.conn)
221 .await?;
222
223 let Some(session) = res else { return Ok(None) };
224
225 Ok(Some(session.try_into()?))
226 }
227
228 #[tracing::instrument(
229 name = "db.oauth2_session.add",
230 skip_all,
231 fields(
232 db.query.text,
233 %client.id,
234 session.id,
235 session.scope = %scope,
236 ),
237 err,
238 )]
239 async fn add(
240 &mut self,
241 rng: &mut (dyn RngCore + Send),
242 clock: &dyn Clock,
243 client: &Client,
244 user: Option<&User>,
245 user_session: Option<&BrowserSession>,
246 scope: Scope,
247 ) -> Result<Session, Self::Error> {
248 let created_at = clock.now();
249 let id = Ulid::from_datetime_with_source(created_at.into(), rng);
250 tracing::Span::current().record("session.id", tracing::field::display(id));
251
252 let scope_list: Vec<String> = scope.iter().map(|s| s.as_str().to_owned()).collect();
253
254 sqlx::query!(
255 r#"
256 INSERT INTO oauth2_sessions
257 ( oauth2_session_id
258 , user_id
259 , user_session_id
260 , oauth2_client_id
261 , scope_list
262 , created_at
263 )
264 VALUES ($1, $2, $3, $4, $5, $6)
265 "#,
266 Uuid::from(id),
267 user.map(|u| Uuid::from(u.id)),
268 user_session.map(|s| Uuid::from(s.id)),
269 Uuid::from(client.id),
270 &scope_list,
271 created_at,
272 )
273 .traced()
274 .execute(&mut *self.conn)
275 .await?;
276
277 Ok(Session {
278 id,
279 state: SessionState::Valid,
280 created_at,
281 user_id: user.map(|u| u.id),
282 user_session_id: user_session.map(|s| s.id),
283 client_id: client.id,
284 scope,
285 user_agent: None,
286 last_active_at: None,
287 last_active_ip: None,
288 human_name: None,
289 })
290 }
291
292 #[tracing::instrument(
293 name = "db.oauth2_session.finish_bulk",
294 skip_all,
295 fields(
296 db.query.text,
297 ),
298 err,
299 )]
300 async fn finish_bulk(
301 &mut self,
302 clock: &dyn Clock,
303 filter: OAuth2SessionFilter<'_>,
304 ) -> Result<usize, Self::Error> {
305 let finished_at = clock.now();
306 let (sql, arguments) = Query::update()
307 .table(OAuth2Sessions::Table)
308 .value(OAuth2Sessions::FinishedAt, finished_at)
309 .apply_filter(filter)
310 .build_sqlx(PostgresQueryBuilder);
311
312 let res = sqlx::query_with(&sql, arguments)
313 .traced()
314 .execute(&mut *self.conn)
315 .await?;
316
317 Ok(res.rows_affected().try_into().unwrap_or(usize::MAX))
318 }
319
320 #[tracing::instrument(
321 name = "db.oauth2_session.finish",
322 skip_all,
323 fields(
324 db.query.text,
325 %session.id,
326 %session.scope,
327 client.id = %session.client_id,
328 ),
329 err,
330 )]
331 async fn finish(
332 &mut self,
333 clock: &dyn Clock,
334 session: Session,
335 ) -> Result<Session, Self::Error> {
336 let finished_at = clock.now();
337 let res = sqlx::query!(
338 r#"
339 UPDATE oauth2_sessions
340 SET finished_at = $2
341 WHERE oauth2_session_id = $1
342 "#,
343 Uuid::from(session.id),
344 finished_at,
345 )
346 .traced()
347 .execute(&mut *self.conn)
348 .await?;
349
350 DatabaseError::ensure_affected_rows(&res, 1)?;
351
352 session
353 .finish(finished_at)
354 .map_err(DatabaseError::to_invalid_operation)
355 }
356
357 #[tracing::instrument(
358 name = "db.oauth2_session.list",
359 skip_all,
360 fields(
361 db.query.text,
362 ),
363 err,
364 )]
365 async fn list(
366 &mut self,
367 filter: OAuth2SessionFilter<'_>,
368 pagination: Pagination,
369 ) -> Result<Page<Session>, Self::Error> {
370 let (sql, arguments) = Query::select()
371 .expr_as(
372 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2SessionId)),
373 OAuthSessionLookupIden::Oauth2SessionId,
374 )
375 .expr_as(
376 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserId)),
377 OAuthSessionLookupIden::UserId,
378 )
379 .expr_as(
380 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserSessionId)),
381 OAuthSessionLookupIden::UserSessionId,
382 )
383 .expr_as(
384 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2ClientId)),
385 OAuthSessionLookupIden::Oauth2ClientId,
386 )
387 .expr_as(
388 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::ScopeList)),
389 OAuthSessionLookupIden::ScopeList,
390 )
391 .expr_as(
392 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::CreatedAt)),
393 OAuthSessionLookupIden::CreatedAt,
394 )
395 .expr_as(
396 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)),
397 OAuthSessionLookupIden::FinishedAt,
398 )
399 .expr_as(
400 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserAgent)),
401 OAuthSessionLookupIden::UserAgent,
402 )
403 .expr_as(
404 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::LastActiveAt)),
405 OAuthSessionLookupIden::LastActiveAt,
406 )
407 .expr_as(
408 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::LastActiveIp)),
409 OAuthSessionLookupIden::LastActiveIp,
410 )
411 .expr_as(
412 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::HumanName)),
413 OAuthSessionLookupIden::HumanName,
414 )
415 .from(OAuth2Sessions::Table)
416 .apply_filter(filter)
417 .generate_pagination(
418 (OAuth2Sessions::Table, OAuth2Sessions::OAuth2SessionId),
419 pagination,
420 )
421 .build_sqlx(PostgresQueryBuilder);
422
423 let edges: Vec<OAuthSessionLookup> = sqlx::query_as_with(&sql, arguments)
424 .traced()
425 .fetch_all(&mut *self.conn)
426 .await?;
427
428 let page = pagination.process(edges).try_map(Session::try_from)?;
429
430 Ok(page)
431 }
432
433 #[tracing::instrument(
434 name = "db.oauth2_session.count",
435 skip_all,
436 fields(
437 db.query.text,
438 ),
439 err,
440 )]
441 async fn count(&mut self, filter: OAuth2SessionFilter<'_>) -> Result<usize, Self::Error> {
442 let (sql, arguments) = Query::select()
443 .expr(Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2SessionId)).count())
444 .from(OAuth2Sessions::Table)
445 .apply_filter(filter)
446 .build_sqlx(PostgresQueryBuilder);
447
448 let count: i64 = sqlx::query_scalar_with(&sql, arguments)
449 .traced()
450 .fetch_one(&mut *self.conn)
451 .await?;
452
453 count
454 .try_into()
455 .map_err(DatabaseError::to_invalid_operation)
456 }
457
458 #[tracing::instrument(
459 name = "db.oauth2_session.record_batch_activity",
460 skip_all,
461 fields(
462 db.query.text,
463 ),
464 err,
465 )]
466 async fn record_batch_activity(
467 &mut self,
468 mut activities: Vec<(Ulid, DateTime<Utc>, Option<IpAddr>)>,
469 ) -> Result<(), Self::Error> {
470 activities.sort_unstable();
473 let mut ids = Vec::with_capacity(activities.len());
474 let mut last_activities = Vec::with_capacity(activities.len());
475 let mut ips = Vec::with_capacity(activities.len());
476
477 for (id, last_activity, ip) in activities {
478 ids.push(Uuid::from(id));
479 last_activities.push(last_activity);
480 ips.push(ip);
481 }
482
483 let res = sqlx::query!(
484 r#"
485 UPDATE oauth2_sessions
486 SET last_active_at = GREATEST(t.last_active_at, oauth2_sessions.last_active_at)
487 , last_active_ip = COALESCE(t.last_active_ip, oauth2_sessions.last_active_ip)
488 FROM (
489 SELECT *
490 FROM UNNEST($1::uuid[], $2::timestamptz[], $3::inet[])
491 AS t(oauth2_session_id, last_active_at, last_active_ip)
492 ) AS t
493 WHERE oauth2_sessions.oauth2_session_id = t.oauth2_session_id
494 "#,
495 &ids,
496 &last_activities,
497 &ips as &[Option<IpAddr>],
498 )
499 .traced()
500 .execute(&mut *self.conn)
501 .await?;
502
503 DatabaseError::ensure_affected_rows(&res, ids.len().try_into().unwrap_or(u64::MAX))?;
504
505 Ok(())
506 }
507
508 #[tracing::instrument(
509 name = "db.oauth2_session.record_user_agent",
510 skip_all,
511 fields(
512 db.query.text,
513 %session.id,
514 %session.scope,
515 client.id = %session.client_id,
516 session.user_agent = user_agent,
517 ),
518 err,
519 )]
520 async fn record_user_agent(
521 &mut self,
522 mut session: Session,
523 user_agent: String,
524 ) -> Result<Session, Self::Error> {
525 let res = sqlx::query!(
526 r#"
527 UPDATE oauth2_sessions
528 SET user_agent = $2
529 WHERE oauth2_session_id = $1
530 "#,
531 Uuid::from(session.id),
532 &*user_agent,
533 )
534 .traced()
535 .execute(&mut *self.conn)
536 .await?;
537
538 session.user_agent = Some(user_agent);
539
540 DatabaseError::ensure_affected_rows(&res, 1)?;
541
542 Ok(session)
543 }
544
545 #[tracing::instrument(
546 name = "repository.oauth2_session.set_human_name",
547 skip(self),
548 fields(
549 client.id = %session.client_id,
550 session.human_name = ?human_name,
551 ),
552 err,
553 )]
554 async fn set_human_name(
555 &mut self,
556 mut session: Session,
557 human_name: Option<String>,
558 ) -> Result<Session, Self::Error> {
559 let res = sqlx::query!(
560 r#"
561 UPDATE oauth2_sessions
562 SET human_name = $2
563 WHERE oauth2_session_id = $1
564 "#,
565 Uuid::from(session.id),
566 human_name.as_deref(),
567 )
568 .traced()
569 .execute(&mut *self.conn)
570 .await?;
571
572 session.human_name = human_name;
573
574 DatabaseError::ensure_affected_rows(&res, 1)?;
575
576 Ok(session)
577 }
578}