1use async_trait::async_trait;
8use chrono::{DateTime, Utc};
9use mas_data_model::{BrowserSession, Clock, CompatSession, CompatSsoLogin, CompatSsoLoginState};
10use mas_storage::{
11 Page, Pagination,
12 compat::{CompatSsoLoginFilter, CompatSsoLoginRepository},
13 pagination::Node,
14};
15use rand::RngCore;
16use sea_query::{Expr, PostgresQueryBuilder, Query, enum_def};
17use sea_query_binder::SqlxBinder;
18use sqlx::PgConnection;
19use ulid::Ulid;
20use url::Url;
21use uuid::Uuid;
22
23use crate::{
24 DatabaseError, DatabaseInconsistencyError,
25 filter::{Filter, StatementExt},
26 iden::{CompatSsoLogins, UserSessions},
27 pagination::QueryBuilderExt,
28 tracing::ExecuteExt,
29};
30
/// An implementation of [`CompatSsoLoginRepository`] backed by a PostgreSQL
/// connection.
pub struct PgCompatSsoLoginRepository<'c> {
    // Borrowed connection on which every query of this repository is executed.
    conn: &'c mut PgConnection,
}
36
37impl<'c> PgCompatSsoLoginRepository<'c> {
38 pub fn new(conn: &'c mut PgConnection) -> Self {
41 Self { conn }
42 }
43}
44
/// Raw row of the `compat_sso_logins` table, as selected by this repository.
///
/// `#[enum_def]` generates the `CompatSsoLoginLookupIden` enum, which the
/// `list` query uses to alias its selected columns. `sqlx::FromRow` then maps
/// those columns back onto the fields by name, so field names must match the
/// selected column names.
#[derive(sqlx::FromRow, Debug)]
#[enum_def]
struct CompatSsoLoginLookup {
    // Primary key; converted to a `Ulid` for the domain model and pagination cursor.
    compat_sso_login_id: Uuid,
    login_token: String,
    // Stored as text; parsed into a `Url` when converting to `CompatSsoLogin`.
    redirect_uri: String,
    created_at: DateTime<Utc>,
    // Set, together with `user_session_id`, by the repository's `fulfill`.
    fulfilled_at: Option<DateTime<Utc>>,
    // Set, together with `compat_session_id`, by the repository's `exchange`.
    exchanged_at: Option<DateTime<Utc>>,
    user_session_id: Option<Uuid>,
    compat_session_id: Option<Uuid>,
}
57
58impl Node<Ulid> for CompatSsoLoginLookup {
59 fn cursor(&self) -> Ulid {
60 self.compat_sso_login_id.into()
61 }
62}
63
64impl TryFrom<CompatSsoLoginLookup> for CompatSsoLogin {
65 type Error = DatabaseInconsistencyError;
66
67 fn try_from(res: CompatSsoLoginLookup) -> Result<Self, Self::Error> {
68 let id = res.compat_sso_login_id.into();
69 let redirect_uri = Url::parse(&res.redirect_uri).map_err(|e| {
70 DatabaseInconsistencyError::on("compat_sso_logins")
71 .column("redirect_uri")
72 .row(id)
73 .source(e)
74 })?;
75
76 let state = match (
77 res.fulfilled_at,
78 res.exchanged_at,
79 res.user_session_id,
80 res.compat_session_id,
81 ) {
82 (None, None, None, None) => CompatSsoLoginState::Pending,
83 (Some(fulfilled_at), None, Some(browser_session_id), None) => {
84 CompatSsoLoginState::Fulfilled {
85 fulfilled_at,
86 browser_session_id: browser_session_id.into(),
87 }
88 }
89 (Some(fulfilled_at), Some(exchanged_at), _, Some(compat_session_id)) => {
90 CompatSsoLoginState::Exchanged {
91 fulfilled_at,
92 exchanged_at,
93 compat_session_id: compat_session_id.into(),
94 }
95 }
96 _ => return Err(DatabaseInconsistencyError::on("compat_sso_logins").row(id)),
97 };
98
99 Ok(CompatSsoLogin {
100 id,
101 login_token: res.login_token,
102 redirect_uri,
103 created_at: res.created_at,
104 state,
105 })
106 }
107}
108
109impl Filter for CompatSsoLoginFilter<'_> {
110 fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
111 sea_query::Condition::all()
112 .add_option(self.user().map(|user| {
113 Expr::exists(
114 Query::select()
115 .expr(Expr::cust("1"))
116 .from(UserSessions::Table)
117 .and_where(
118 Expr::col((UserSessions::Table, UserSessions::UserId))
119 .eq(Uuid::from(user.id)),
120 )
121 .and_where(
122 Expr::col((CompatSsoLogins::Table, CompatSsoLogins::UserSessionId))
123 .equals((UserSessions::Table, UserSessions::UserSessionId)),
124 )
125 .take(),
126 )
127 }))
128 .add_option(self.state().map(|state| {
129 if state.is_exchanged() {
130 Expr::col((CompatSsoLogins::Table, CompatSsoLogins::ExchangedAt)).is_not_null()
131 } else if state.is_fulfilled() {
132 Expr::col((CompatSsoLogins::Table, CompatSsoLogins::FulfilledAt))
133 .is_not_null()
134 .and(
135 Expr::col((CompatSsoLogins::Table, CompatSsoLogins::ExchangedAt))
136 .is_null(),
137 )
138 } else {
139 Expr::col((CompatSsoLogins::Table, CompatSsoLogins::FulfilledAt)).is_null()
140 }
141 }))
142 }
143}
144
// PostgreSQL-backed implementation of the compat SSO login repository.
// Single-row queries use compile-time checked `sqlx::query!`/`query_as!`
// macros; `list` and `count` build their SQL dynamically with sea-query so
// that the `CompatSsoLoginFilter` can be applied.
#[async_trait]
impl CompatSsoLoginRepository for PgCompatSsoLoginRepository<'_> {
    type Error = DatabaseError;

    /// Looks up a compat SSO login by its ID.
    ///
    /// Returns `Ok(None)` when no row has this ID.
    ///
    /// # Errors
    ///
    /// Returns a [`DatabaseError`] if the query fails or if the row cannot be
    /// converted into a [`CompatSsoLogin`].
    #[tracing::instrument(
        name = "db.compat_sso_login.lookup",
        skip_all,
        fields(
            db.query.text,
            compat_sso_login.id = %id,
        ),
        err,
    )]
    async fn lookup(&mut self, id: Ulid) -> Result<Option<CompatSsoLogin>, Self::Error> {
        let res = sqlx::query_as!(
            CompatSsoLoginLookup,
            r#"
                SELECT compat_sso_login_id
                     , login_token
                     , redirect_uri
                     , created_at
                     , fulfilled_at
                     , exchanged_at
                     , compat_session_id
                     , user_session_id

                FROM compat_sso_logins
                WHERE compat_sso_login_id = $1
            "#,
            Uuid::from(id),
        )
        .traced()
        .fetch_optional(&mut *self.conn)
        .await?;

        let Some(res) = res else { return Ok(None) };

        Ok(Some(res.try_into()?))
    }

    /// Finds the compat SSO login whose `compat_session_id` is the given
    /// session. That column is written by [`Self::exchange`].
    ///
    /// Returns `Ok(None)` when no login references this session.
    ///
    /// # Errors
    ///
    /// Returns a [`DatabaseError`] if the query fails or if the row cannot be
    /// converted into a [`CompatSsoLogin`].
    #[tracing::instrument(
        name = "db.compat_sso_login.find_for_session",
        skip_all,
        fields(
            db.query.text,
            %compat_session.id,
        ),
        err,
    )]
    async fn find_for_session(
        &mut self,
        compat_session: &CompatSession,
    ) -> Result<Option<CompatSsoLogin>, Self::Error> {
        let res = sqlx::query_as!(
            CompatSsoLoginLookup,
            r#"
                SELECT compat_sso_login_id
                     , login_token
                     , redirect_uri
                     , created_at
                     , fulfilled_at
                     , exchanged_at
                     , compat_session_id
                     , user_session_id

                FROM compat_sso_logins
                WHERE compat_session_id = $1
            "#,
            Uuid::from(compat_session.id),
        )
        .traced()
        .fetch_optional(&mut *self.conn)
        .await?;

        let Some(res) = res else { return Ok(None) };

        Ok(Some(res.try_into()?))
    }

    /// Finds a compat SSO login by its login token.
    ///
    /// The token value is not recorded on the tracing span (`skip_all` with
    /// no token field).
    ///
    /// # Errors
    ///
    /// Returns a [`DatabaseError`] if the query fails or if the row cannot be
    /// converted into a [`CompatSsoLogin`].
    #[tracing::instrument(
        name = "db.compat_sso_login.find_by_token",
        skip_all,
        fields(
            db.query.text,
        ),
        err,
    )]
    async fn find_by_token(
        &mut self,
        login_token: &str,
    ) -> Result<Option<CompatSsoLogin>, Self::Error> {
        let res = sqlx::query_as!(
            CompatSsoLoginLookup,
            r#"
                SELECT compat_sso_login_id
                     , login_token
                     , redirect_uri
                     , created_at
                     , fulfilled_at
                     , exchanged_at
                     , compat_session_id
                     , user_session_id

                FROM compat_sso_logins
                WHERE login_token = $1
            "#,
            login_token,
        )
        .traced()
        .fetch_optional(&mut *self.conn)
        .await?;

        let Some(res) = res else { return Ok(None) };

        Ok(Some(res.try_into()?))
    }

    /// Inserts a new compat SSO login.
    ///
    /// The ID is a ULID derived from the current time (`clock.now()`) and
    /// `rng`, and it is recorded on the tracing span. Only the ID, token,
    /// redirect URI and creation time are inserted. The fulfilment and
    /// exchange columns are left unset, and the returned value carries
    /// `CompatSsoLoginState::default()`.
    ///
    /// # Errors
    ///
    /// Returns a [`DatabaseError`] if the insert fails.
    #[tracing::instrument(
        name = "db.compat_sso_login.add",
        skip_all,
        fields(
            db.query.text,
            compat_sso_login.id,
            compat_sso_login.redirect_uri = %redirect_uri,
        ),
        err,
    )]
    async fn add(
        &mut self,
        rng: &mut (dyn RngCore + Send),
        clock: &dyn Clock,
        login_token: String,
        redirect_uri: Url,
    ) -> Result<CompatSsoLogin, Self::Error> {
        let created_at = clock.now();
        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
        // The span field is declared empty above and filled in once the ID exists.
        tracing::Span::current().record("compat_sso_login.id", tracing::field::display(id));

        sqlx::query!(
            r#"
                INSERT INTO compat_sso_logins
                    (compat_sso_login_id, login_token, redirect_uri, created_at)
                VALUES ($1, $2, $3, $4)
            "#,
            Uuid::from(id),
            &login_token,
            redirect_uri.as_str(),
            created_at,
        )
        .traced()
        .execute(&mut *self.conn)
        .await?;

        Ok(CompatSsoLogin {
            id,
            login_token,
            redirect_uri,
            created_at,
            state: CompatSsoLoginState::default(),
        })
    }

    /// Marks a login as fulfilled by `browser_session`, timestamped with
    /// `clock.now()`.
    ///
    /// The state transition is first validated on the in-memory value by
    /// `CompatSsoLogin::fulfill`. Only then is the row updated.
    ///
    /// # Errors
    ///
    /// Returns an invalid-operation [`DatabaseError`] if the transition is
    /// rejected. Returns a [`DatabaseError`] if the update fails or does not
    /// affect exactly one row.
    #[tracing::instrument(
        name = "db.compat_sso_login.fulfill",
        skip_all,
        fields(
            db.query.text,
            %compat_sso_login.id,
            %browser_session.id,
            user.id = %browser_session.user.id,
        ),
        err,
    )]
    async fn fulfill(
        &mut self,
        clock: &dyn Clock,
        compat_sso_login: CompatSsoLogin,
        browser_session: &BrowserSession,
    ) -> Result<CompatSsoLogin, Self::Error> {
        let fulfilled_at = clock.now();
        let compat_sso_login = compat_sso_login
            .fulfill(fulfilled_at, browser_session)
            .map_err(DatabaseError::to_invalid_operation)?;

        // NOTE(review): the UPDATE matches on the ID only and does not re-check
        // the stored state (e.g. `AND fulfilled_at IS NULL`). Confirm callers
        // serialize concurrent fulfilments. Changing the SQL also requires
        // regenerating the sqlx offline query data.
        let res = sqlx::query!(
            r#"
                UPDATE compat_sso_logins
                SET
                    user_session_id = $2,
                    fulfilled_at = $3
                WHERE
                    compat_sso_login_id = $1
            "#,
            Uuid::from(compat_sso_login.id),
            Uuid::from(browser_session.id),
            fulfilled_at,
        )
        .traced()
        .execute(&mut *self.conn)
        .await?;

        DatabaseError::ensure_affected_rows(&res, 1)?;

        Ok(compat_sso_login)
    }

    /// Marks a login as exchanged for `compat_session`, timestamped with
    /// `clock.now()`.
    ///
    /// The state transition is first validated on the in-memory value by
    /// `CompatSsoLogin::exchange`. Only then is the row updated.
    ///
    /// # Errors
    ///
    /// Returns an invalid-operation [`DatabaseError`] if the transition is
    /// rejected. Returns a [`DatabaseError`] if the update fails or does not
    /// affect exactly one row.
    #[tracing::instrument(
        name = "db.compat_sso_login.exchange",
        skip_all,
        fields(
            db.query.text,
            %compat_sso_login.id,
            %compat_session.id,
            compat_session.device.id = compat_session.device.as_ref().map(mas_data_model::Device::as_str),
        ),
        err,
    )]
    async fn exchange(
        &mut self,
        clock: &dyn Clock,
        compat_sso_login: CompatSsoLogin,
        compat_session: &CompatSession,
    ) -> Result<CompatSsoLogin, Self::Error> {
        let exchanged_at = clock.now();
        let compat_sso_login = compat_sso_login
            .exchange(exchanged_at, compat_session)
            .map_err(DatabaseError::to_invalid_operation)?;

        // NOTE(review): as in `fulfill`, the stored state is not re-checked by
        // the UPDATE (e.g. `AND exchanged_at IS NULL`). Confirm this is intended.
        let res = sqlx::query!(
            r#"
                UPDATE compat_sso_logins
                SET
                    exchanged_at = $2,
                    compat_session_id = $3
                WHERE
                    compat_sso_login_id = $1
            "#,
            Uuid::from(compat_sso_login.id),
            exchanged_at,
            Uuid::from(compat_session.id),
        )
        .traced()
        .execute(&mut *self.conn)
        .await?;

        DatabaseError::ensure_affected_rows(&res, 1)?;

        Ok(compat_sso_login)
    }

    /// Lists the logins matching `filter`, one page at a time.
    ///
    /// Columns are aliased with the `CompatSsoLoginLookupIden` names so the
    /// rows decode into [`CompatSsoLoginLookup`]. Pagination is keyed on
    /// `compat_sso_login_id`.
    ///
    /// # Errors
    ///
    /// Returns a [`DatabaseError`] if the query fails or if a row cannot be
    /// converted into a [`CompatSsoLogin`].
    #[tracing::instrument(
        name = "db.compat_sso_login.list",
        skip_all,
        fields(
            db.query.text,
        ),
        err
    )]
    async fn list(
        &mut self,
        filter: CompatSsoLoginFilter<'_>,
        pagination: Pagination,
    ) -> Result<Page<CompatSsoLogin>, Self::Error> {
        let (sql, arguments) = Query::select()
            .expr_as(
                Expr::col((CompatSsoLogins::Table, CompatSsoLogins::CompatSsoLoginId)),
                CompatSsoLoginLookupIden::CompatSsoLoginId,
            )
            .expr_as(
                Expr::col((CompatSsoLogins::Table, CompatSsoLogins::CompatSessionId)),
                CompatSsoLoginLookupIden::CompatSessionId,
            )
            .expr_as(
                Expr::col((CompatSsoLogins::Table, CompatSsoLogins::UserSessionId)),
                CompatSsoLoginLookupIden::UserSessionId,
            )
            .expr_as(
                Expr::col((CompatSsoLogins::Table, CompatSsoLogins::LoginToken)),
                CompatSsoLoginLookupIden::LoginToken,
            )
            .expr_as(
                Expr::col((CompatSsoLogins::Table, CompatSsoLogins::RedirectUri)),
                CompatSsoLoginLookupIden::RedirectUri,
            )
            .expr_as(
                Expr::col((CompatSsoLogins::Table, CompatSsoLogins::CreatedAt)),
                CompatSsoLoginLookupIden::CreatedAt,
            )
            .expr_as(
                Expr::col((CompatSsoLogins::Table, CompatSsoLogins::FulfilledAt)),
                CompatSsoLoginLookupIden::FulfilledAt,
            )
            .expr_as(
                Expr::col((CompatSsoLogins::Table, CompatSsoLogins::ExchangedAt)),
                CompatSsoLoginLookupIden::ExchangedAt,
            )
            .from(CompatSsoLogins::Table)
            .apply_filter(filter)
            .generate_pagination(
                (CompatSsoLogins::Table, CompatSsoLogins::CompatSsoLoginId),
                pagination,
            )
            .build_sqlx(PostgresQueryBuilder);

        let edges: Vec<CompatSsoLoginLookup> = sqlx::query_as_with(&sql, arguments)
            .traced()
            .fetch_all(&mut *self.conn)
            .await?;

        let page = pagination.process(edges).try_map(TryFrom::try_from)?;

        Ok(page)
    }

    /// Counts the logins matching `filter`.
    ///
    /// # Errors
    ///
    /// Returns a [`DatabaseError`] if the query fails. Also returns an
    /// invalid-operation error if the `i64` count does not convert to `usize`.
    #[tracing::instrument(
        name = "db.compat_sso_login.count",
        skip_all,
        fields(
            db.query.text,
        ),
        err
    )]
    async fn count(&mut self, filter: CompatSsoLoginFilter<'_>) -> Result<usize, Self::Error> {
        let (sql, arguments) = Query::select()
            .expr(Expr::col((CompatSsoLogins::Table, CompatSsoLogins::CompatSsoLoginId)).count())
            .from(CompatSsoLogins::Table)
            .apply_filter(filter)
            .build_sqlx(PostgresQueryBuilder);

        let count: i64 = sqlx::query_scalar_with(&sql, arguments)
            .traced()
            .fetch_one(&mut *self.conn)
            .await?;

        count
            .try_into()
            .map_err(DatabaseError::to_invalid_operation)
    }
}