1use async_trait::async_trait;
8use chrono::{DateTime, Utc};
9use mas_data_model::{BrowserSession, CompatSession, CompatSsoLogin, CompatSsoLoginState};
10use mas_storage::{
11 Clock, Page, Pagination,
12 compat::{CompatSsoLoginFilter, CompatSsoLoginRepository},
13};
14use rand::RngCore;
15use sea_query::{Expr, PostgresQueryBuilder, Query, enum_def};
16use sea_query_binder::SqlxBinder;
17use sqlx::PgConnection;
18use ulid::Ulid;
19use url::Url;
20use uuid::Uuid;
21
22use crate::{
23 DatabaseError, DatabaseInconsistencyError,
24 filter::{Filter, StatementExt},
25 iden::{CompatSsoLogins, UserSessions},
26 pagination::QueryBuilderExt,
27 tracing::ExecuteExt,
28};
29
/// PostgreSQL-backed implementation of [`CompatSsoLoginRepository`].
///
/// Holds a mutable borrow of a [`PgConnection`] for the lifetime `'c`; the
/// repository methods execute their queries on that connection.
pub struct PgCompatSsoLoginRepository<'c> {
    /// Connection on which the repository's queries are executed.
    conn: &'c mut PgConnection,
}
35
36impl<'c> PgCompatSsoLoginRepository<'c> {
37 pub fn new(conn: &'c mut PgConnection) -> Self {
40 Self { conn }
41 }
42}
43
/// Raw row read from the `compat_sso_logins` table, before it is validated
/// and converted into a [`CompatSsoLogin`].
///
/// The field names must match the column names (or aliases) selected by the
/// queries in this file. `#[enum_def]` generates the companion
/// `CompatSsoLoginLookupIden` enum, which the dynamically built `list` query
/// uses to alias its selected columns.
#[derive(sqlx::FromRow, Debug)]
#[enum_def]
struct CompatSsoLoginLookup {
    /// Identifier of the login, stored as a UUID.
    compat_sso_login_id: Uuid,
    /// Token by which the login can be looked up (`find_by_token`).
    login_token: String,
    /// Redirect URI as stored; parsed into a [`Url`] during conversion.
    redirect_uri: String,
    /// When the row was inserted.
    created_at: DateTime<Utc>,
    /// Set by `fulfill`; `None` while the login is still pending.
    fulfilled_at: Option<DateTime<Utc>>,
    /// Set by `exchange`; `None` until the login has been exchanged.
    exchanged_at: Option<DateTime<Utc>>,
    /// Browser session recorded by `fulfill`, if any.
    user_session_id: Option<Uuid>,
    /// Compat session recorded by `exchange`, if any.
    compat_session_id: Option<Uuid>,
}
56
57impl TryFrom<CompatSsoLoginLookup> for CompatSsoLogin {
58 type Error = DatabaseInconsistencyError;
59
60 fn try_from(res: CompatSsoLoginLookup) -> Result<Self, Self::Error> {
61 let id = res.compat_sso_login_id.into();
62 let redirect_uri = Url::parse(&res.redirect_uri).map_err(|e| {
63 DatabaseInconsistencyError::on("compat_sso_logins")
64 .column("redirect_uri")
65 .row(id)
66 .source(e)
67 })?;
68
69 let state = match (
70 res.fulfilled_at,
71 res.exchanged_at,
72 res.user_session_id,
73 res.compat_session_id,
74 ) {
75 (None, None, None, None) => CompatSsoLoginState::Pending,
76 (Some(fulfilled_at), None, Some(browser_session_id), None) => {
77 CompatSsoLoginState::Fulfilled {
78 fulfilled_at,
79 browser_session_id: browser_session_id.into(),
80 }
81 }
82 (Some(fulfilled_at), Some(exchanged_at), _, Some(compat_session_id)) => {
83 CompatSsoLoginState::Exchanged {
84 fulfilled_at,
85 exchanged_at,
86 compat_session_id: compat_session_id.into(),
87 }
88 }
89 _ => return Err(DatabaseInconsistencyError::on("compat_sso_logins").row(id)),
90 };
91
92 Ok(CompatSsoLogin {
93 id,
94 login_token: res.login_token,
95 redirect_uri,
96 created_at: res.created_at,
97 state,
98 })
99 }
100}
101
102impl Filter for CompatSsoLoginFilter<'_> {
103 fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
104 sea_query::Condition::all()
105 .add_option(self.user().map(|user| {
106 Expr::exists(
107 Query::select()
108 .expr(Expr::cust("1"))
109 .from(UserSessions::Table)
110 .and_where(
111 Expr::col((UserSessions::Table, UserSessions::UserId))
112 .eq(Uuid::from(user.id)),
113 )
114 .and_where(
115 Expr::col((CompatSsoLogins::Table, CompatSsoLogins::UserSessionId))
116 .equals((UserSessions::Table, UserSessions::UserSessionId)),
117 )
118 .take(),
119 )
120 }))
121 .add_option(self.state().map(|state| {
122 if state.is_exchanged() {
123 Expr::col((CompatSsoLogins::Table, CompatSsoLogins::ExchangedAt)).is_not_null()
124 } else if state.is_fulfilled() {
125 Expr::col((CompatSsoLogins::Table, CompatSsoLogins::FulfilledAt))
126 .is_not_null()
127 .and(
128 Expr::col((CompatSsoLogins::Table, CompatSsoLogins::ExchangedAt))
129 .is_null(),
130 )
131 } else {
132 Expr::col((CompatSsoLogins::Table, CompatSsoLogins::FulfilledAt)).is_null()
133 }
134 }))
135 }
136}
137
138#[async_trait]
139impl CompatSsoLoginRepository for PgCompatSsoLoginRepository<'_> {
140 type Error = DatabaseError;
141
142 #[tracing::instrument(
143 name = "db.compat_sso_login.lookup",
144 skip_all,
145 fields(
146 db.query.text,
147 compat_sso_login.id = %id,
148 ),
149 err,
150 )]
151 async fn lookup(&mut self, id: Ulid) -> Result<Option<CompatSsoLogin>, Self::Error> {
152 let res = sqlx::query_as!(
153 CompatSsoLoginLookup,
154 r#"
155 SELECT compat_sso_login_id
156 , login_token
157 , redirect_uri
158 , created_at
159 , fulfilled_at
160 , exchanged_at
161 , compat_session_id
162 , user_session_id
163
164 FROM compat_sso_logins
165 WHERE compat_sso_login_id = $1
166 "#,
167 Uuid::from(id),
168 )
169 .traced()
170 .fetch_optional(&mut *self.conn)
171 .await?;
172
173 let Some(res) = res else { return Ok(None) };
174
175 Ok(Some(res.try_into()?))
176 }
177
178 #[tracing::instrument(
179 name = "db.compat_sso_login.find_for_session",
180 skip_all,
181 fields(
182 db.query.text,
183 %compat_session.id,
184 ),
185 err,
186 )]
187 async fn find_for_session(
188 &mut self,
189 compat_session: &CompatSession,
190 ) -> Result<Option<CompatSsoLogin>, Self::Error> {
191 let res = sqlx::query_as!(
192 CompatSsoLoginLookup,
193 r#"
194 SELECT compat_sso_login_id
195 , login_token
196 , redirect_uri
197 , created_at
198 , fulfilled_at
199 , exchanged_at
200 , compat_session_id
201 , user_session_id
202
203 FROM compat_sso_logins
204 WHERE compat_session_id = $1
205 "#,
206 Uuid::from(compat_session.id),
207 )
208 .traced()
209 .fetch_optional(&mut *self.conn)
210 .await?;
211
212 let Some(res) = res else { return Ok(None) };
213
214 Ok(Some(res.try_into()?))
215 }
216
217 #[tracing::instrument(
218 name = "db.compat_sso_login.find_by_token",
219 skip_all,
220 fields(
221 db.query.text,
222 ),
223 err,
224 )]
225 async fn find_by_token(
226 &mut self,
227 login_token: &str,
228 ) -> Result<Option<CompatSsoLogin>, Self::Error> {
229 let res = sqlx::query_as!(
230 CompatSsoLoginLookup,
231 r#"
232 SELECT compat_sso_login_id
233 , login_token
234 , redirect_uri
235 , created_at
236 , fulfilled_at
237 , exchanged_at
238 , compat_session_id
239 , user_session_id
240
241 FROM compat_sso_logins
242 WHERE login_token = $1
243 "#,
244 login_token,
245 )
246 .traced()
247 .fetch_optional(&mut *self.conn)
248 .await?;
249
250 let Some(res) = res else { return Ok(None) };
251
252 Ok(Some(res.try_into()?))
253 }
254
255 #[tracing::instrument(
256 name = "db.compat_sso_login.add",
257 skip_all,
258 fields(
259 db.query.text,
260 compat_sso_login.id,
261 compat_sso_login.redirect_uri = %redirect_uri,
262 ),
263 err,
264 )]
265 async fn add(
266 &mut self,
267 rng: &mut (dyn RngCore + Send),
268 clock: &dyn Clock,
269 login_token: String,
270 redirect_uri: Url,
271 ) -> Result<CompatSsoLogin, Self::Error> {
272 let created_at = clock.now();
273 let id = Ulid::from_datetime_with_source(created_at.into(), rng);
274 tracing::Span::current().record("compat_sso_login.id", tracing::field::display(id));
275
276 sqlx::query!(
277 r#"
278 INSERT INTO compat_sso_logins
279 (compat_sso_login_id, login_token, redirect_uri, created_at)
280 VALUES ($1, $2, $3, $4)
281 "#,
282 Uuid::from(id),
283 &login_token,
284 redirect_uri.as_str(),
285 created_at,
286 )
287 .traced()
288 .execute(&mut *self.conn)
289 .await?;
290
291 Ok(CompatSsoLogin {
292 id,
293 login_token,
294 redirect_uri,
295 created_at,
296 state: CompatSsoLoginState::default(),
297 })
298 }
299
300 #[tracing::instrument(
301 name = "db.compat_sso_login.fulfill",
302 skip_all,
303 fields(
304 db.query.text,
305 %compat_sso_login.id,
306 %browser_session.id,
307 user.id = %browser_session.user.id,
308 ),
309 err,
310 )]
311 async fn fulfill(
312 &mut self,
313 clock: &dyn Clock,
314 compat_sso_login: CompatSsoLogin,
315 browser_session: &BrowserSession,
316 ) -> Result<CompatSsoLogin, Self::Error> {
317 let fulfilled_at = clock.now();
318 let compat_sso_login = compat_sso_login
319 .fulfill(fulfilled_at, browser_session)
320 .map_err(DatabaseError::to_invalid_operation)?;
321
322 let res = sqlx::query!(
323 r#"
324 UPDATE compat_sso_logins
325 SET
326 user_session_id = $2,
327 fulfilled_at = $3
328 WHERE
329 compat_sso_login_id = $1
330 "#,
331 Uuid::from(compat_sso_login.id),
332 Uuid::from(browser_session.id),
333 fulfilled_at,
334 )
335 .traced()
336 .execute(&mut *self.conn)
337 .await?;
338
339 DatabaseError::ensure_affected_rows(&res, 1)?;
340
341 Ok(compat_sso_login)
342 }
343
344 #[tracing::instrument(
345 name = "db.compat_sso_login.exchange",
346 skip_all,
347 fields(
348 db.query.text,
349 %compat_sso_login.id,
350 %compat_session.id,
351 compat_session.device.id = compat_session.device.as_ref().map(mas_data_model::Device::as_str),
352 ),
353 err,
354 )]
355 async fn exchange(
356 &mut self,
357 clock: &dyn Clock,
358 compat_sso_login: CompatSsoLogin,
359 compat_session: &CompatSession,
360 ) -> Result<CompatSsoLogin, Self::Error> {
361 let exchanged_at = clock.now();
362 let compat_sso_login = compat_sso_login
363 .exchange(exchanged_at, compat_session)
364 .map_err(DatabaseError::to_invalid_operation)?;
365
366 let res = sqlx::query!(
367 r#"
368 UPDATE compat_sso_logins
369 SET
370 exchanged_at = $2,
371 compat_session_id = $3
372 WHERE
373 compat_sso_login_id = $1
374 "#,
375 Uuid::from(compat_sso_login.id),
376 exchanged_at,
377 Uuid::from(compat_session.id),
378 )
379 .traced()
380 .execute(&mut *self.conn)
381 .await?;
382
383 DatabaseError::ensure_affected_rows(&res, 1)?;
384
385 Ok(compat_sso_login)
386 }
387
388 #[tracing::instrument(
389 name = "db.compat_sso_login.list",
390 skip_all,
391 fields(
392 db.query.text,
393 ),
394 err
395 )]
396 async fn list(
397 &mut self,
398 filter: CompatSsoLoginFilter<'_>,
399 pagination: Pagination,
400 ) -> Result<Page<CompatSsoLogin>, Self::Error> {
401 let (sql, arguments) = Query::select()
402 .expr_as(
403 Expr::col((CompatSsoLogins::Table, CompatSsoLogins::CompatSsoLoginId)),
404 CompatSsoLoginLookupIden::CompatSsoLoginId,
405 )
406 .expr_as(
407 Expr::col((CompatSsoLogins::Table, CompatSsoLogins::CompatSessionId)),
408 CompatSsoLoginLookupIden::CompatSessionId,
409 )
410 .expr_as(
411 Expr::col((CompatSsoLogins::Table, CompatSsoLogins::UserSessionId)),
412 CompatSsoLoginLookupIden::UserSessionId,
413 )
414 .expr_as(
415 Expr::col((CompatSsoLogins::Table, CompatSsoLogins::LoginToken)),
416 CompatSsoLoginLookupIden::LoginToken,
417 )
418 .expr_as(
419 Expr::col((CompatSsoLogins::Table, CompatSsoLogins::RedirectUri)),
420 CompatSsoLoginLookupIden::RedirectUri,
421 )
422 .expr_as(
423 Expr::col((CompatSsoLogins::Table, CompatSsoLogins::CreatedAt)),
424 CompatSsoLoginLookupIden::CreatedAt,
425 )
426 .expr_as(
427 Expr::col((CompatSsoLogins::Table, CompatSsoLogins::FulfilledAt)),
428 CompatSsoLoginLookupIden::FulfilledAt,
429 )
430 .expr_as(
431 Expr::col((CompatSsoLogins::Table, CompatSsoLogins::ExchangedAt)),
432 CompatSsoLoginLookupIden::ExchangedAt,
433 )
434 .from(CompatSsoLogins::Table)
435 .apply_filter(filter)
436 .generate_pagination(
437 (CompatSsoLogins::Table, CompatSsoLogins::CompatSsoLoginId),
438 pagination,
439 )
440 .build_sqlx(PostgresQueryBuilder);
441
442 let edges: Vec<CompatSsoLoginLookup> = sqlx::query_as_with(&sql, arguments)
443 .traced()
444 .fetch_all(&mut *self.conn)
445 .await?;
446
447 let page = pagination.process(edges).try_map(TryFrom::try_from)?;
448
449 Ok(page)
450 }
451
452 #[tracing::instrument(
453 name = "db.compat_sso_login.count",
454 skip_all,
455 fields(
456 db.query.text,
457 ),
458 err
459 )]
460 async fn count(&mut self, filter: CompatSsoLoginFilter<'_>) -> Result<usize, Self::Error> {
461 let (sql, arguments) = Query::select()
462 .expr(Expr::col((CompatSsoLogins::Table, CompatSsoLogins::CompatSsoLoginId)).count())
463 .from(CompatSsoLogins::Table)
464 .apply_filter(filter)
465 .build_sqlx(PostgresQueryBuilder);
466
467 let count: i64 = sqlx::query_scalar_with(&sql, arguments)
468 .traced()
469 .fetch_one(&mut *self.conn)
470 .await?;
471
472 count
473 .try_into()
474 .map_err(DatabaseError::to_invalid_operation)
475 }
476}