1use std::collections::BTreeMap;
9
10use async_trait::async_trait;
11use chrono::{DateTime, Utc};
12use mas_data_model::{
13 AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, Client, Clock, Pkce, Session,
14 UlidExt as _,
15};
16use mas_iana::oauth::PkceCodeChallengeMethod;
17use mas_storage::oauth2::OAuth2AuthorizationGrantRepository;
18use oauth2_types::{requests::ResponseMode, scope::Scope};
19use rand::RngCore;
20use sqlx::{PgConnection, types::Json};
21use ulid::Ulid;
22use url::Url;
23use uuid::Uuid;
24
25use crate::{DatabaseError, DatabaseInconsistencyError, tracing::ExecuteExt};
26
/// An implementation of [`OAuth2AuthorizationGrantRepository`] backed by a
/// PostgreSQL connection.
pub struct PgOAuth2AuthorizationGrantRepository<'c> {
    // Exclusively borrowed connection on which every query of this repository
    // is executed.
    conn: &'c mut PgConnection,
}
32
impl<'c> PgOAuth2AuthorizationGrantRepository<'c> {
    /// Create a new [`PgOAuth2AuthorizationGrantRepository`] operating on the
    /// given PostgreSQL connection.
    ///
    /// The connection stays mutably borrowed for as long as the repository
    /// lives.
    pub fn new(conn: &'c mut PgConnection) -> Self {
        Self { conn }
    }
}
40
/// Raw row of the `oauth2_authorization_grants` table, as returned by the
/// `lookup` and `find_by_code` queries of this repository.
///
/// The rows are mapped with `sqlx::query_as!`, so every field name must match
/// a selected column name. The columns are not validated here; converting into
/// an [`AuthorizationGrant`] through `TryFrom` checks that the combination of
/// nullable columns is consistent.
struct GrantLookup {
    oauth2_authorization_grant_id: Uuid,
    created_at: DateTime<Utc>,
    // `cancelled_at`, `fulfilled_at`, `exchanged_at` and `oauth2_session_id`
    // together encode the grant stage (pending/fulfilled/exchanged/cancelled).
    cancelled_at: Option<DateTime<Utc>>,
    fulfilled_at: Option<DateTime<Utc>>,
    exchanged_at: Option<DateTime<Utc>>,
    // Serialized `Scope`, parsed back when converting.
    scope: String,
    state: Option<String>,
    nonce: Option<String>,
    // Serialized URL, parsed back when converting.
    redirect_uri: String,
    // Serialized `ResponseMode`, parsed back when converting.
    response_mode: String,
    // When true, `authorization_code` is expected to be set, and vice versa.
    response_type_code: bool,
    response_type_id_token: bool,
    authorization_code: Option<String>,
    // PKCE parameters: either both set or both NULL; the method is expected to
    // be either "plain" or "S256".
    code_challenge: Option<String>,
    code_challenge_method: Option<String>,
    login_hint: Option<String>,
    locale: Option<String>,
    // JSON-encoded map of parameters; a NULL value is treated as an empty map.
    raw_parameters: Option<Json<BTreeMap<String, String>>>,
    oauth2_client_id: Uuid,
    // Set once the grant has been fulfilled.
    oauth2_session_id: Option<Uuid>,
}
63
64impl TryFrom<GrantLookup> for AuthorizationGrant {
65 type Error = DatabaseInconsistencyError;
66
67 fn try_from(value: GrantLookup) -> Result<Self, Self::Error> {
68 let id = value.oauth2_authorization_grant_id.into();
69 let scope: Scope = value.scope.parse().map_err(|e| {
70 DatabaseInconsistencyError::on("oauth2_authorization_grants")
71 .column("scope")
72 .row(id)
73 .source(e)
74 })?;
75
76 let stage = match (
77 value.fulfilled_at,
78 value.exchanged_at,
79 value.cancelled_at,
80 value.oauth2_session_id,
81 ) {
82 (None, None, None, None) => AuthorizationGrantStage::Pending,
83 (Some(fulfilled_at), None, None, Some(session_id)) => {
84 AuthorizationGrantStage::Fulfilled {
85 session_id: session_id.into(),
86 fulfilled_at,
87 }
88 }
89 (Some(fulfilled_at), Some(exchanged_at), None, Some(session_id)) => {
90 AuthorizationGrantStage::Exchanged {
91 session_id: session_id.into(),
92 fulfilled_at,
93 exchanged_at,
94 }
95 }
96 (None, None, Some(cancelled_at), None) => {
97 AuthorizationGrantStage::Cancelled { cancelled_at }
98 }
99 _ => {
100 return Err(
101 DatabaseInconsistencyError::on("oauth2_authorization_grants")
102 .column("stage")
103 .row(id),
104 );
105 }
106 };
107
108 let pkce = match (value.code_challenge, value.code_challenge_method) {
109 (Some(challenge), Some(challenge_method)) if challenge_method == "plain" => {
110 Some(Pkce {
111 challenge_method: PkceCodeChallengeMethod::Plain,
112 challenge,
113 })
114 }
115 (Some(challenge), Some(challenge_method)) if challenge_method == "S256" => Some(Pkce {
116 challenge_method: PkceCodeChallengeMethod::S256,
117 challenge,
118 }),
119 (None, None) => None,
120 _ => {
121 return Err(
122 DatabaseInconsistencyError::on("oauth2_authorization_grants")
123 .column("code_challenge_method")
124 .row(id),
125 );
126 }
127 };
128
129 let code: Option<AuthorizationCode> =
130 match (value.response_type_code, value.authorization_code, pkce) {
131 (false, None, None) => None,
132 (true, Some(code), pkce) => Some(AuthorizationCode { code, pkce }),
133 _ => {
134 return Err(
135 DatabaseInconsistencyError::on("oauth2_authorization_grants")
136 .column("authorization_code")
137 .row(id),
138 );
139 }
140 };
141
142 let redirect_uri = value.redirect_uri.parse().map_err(|e| {
143 DatabaseInconsistencyError::on("oauth2_authorization_grants")
144 .column("redirect_uri")
145 .row(id)
146 .source(e)
147 })?;
148
149 let response_mode = value.response_mode.parse().map_err(|e| {
150 DatabaseInconsistencyError::on("oauth2_authorization_grants")
151 .column("response_mode")
152 .row(id)
153 .source(e)
154 })?;
155
156 Ok(AuthorizationGrant {
157 id,
158 stage,
159 client_id: value.oauth2_client_id.into(),
160 code,
161 scope,
162 state: value.state,
163 nonce: value.nonce,
164 response_mode,
165 redirect_uri,
166 created_at: value.created_at,
167 response_type_id_token: value.response_type_id_token,
168 login_hint: value.login_hint,
169 locale: value.locale,
170 raw_parameters: value.raw_parameters.map(|Json(x)| x).unwrap_or_default(),
171 })
172 }
173}
174
175#[async_trait]
176impl OAuth2AuthorizationGrantRepository for PgOAuth2AuthorizationGrantRepository<'_> {
177 type Error = DatabaseError;
178
179 #[tracing::instrument(
180 name = "db.oauth2_authorization_grant.add",
181 skip_all,
182 fields(
183 db.query.text,
184 grant.id,
185 grant.scope = %scope,
186 %client.id,
187 ),
188 err,
189 )]
190 async fn add(
191 &mut self,
192 rng: &mut (dyn RngCore + Send),
193 clock: &dyn Clock,
194 client: &Client,
195 redirect_uri: Url,
196 scope: Scope,
197 code: Option<AuthorizationCode>,
198 state: Option<String>,
199 nonce: Option<String>,
200 response_mode: ResponseMode,
201 response_type_id_token: bool,
202 login_hint: Option<String>,
203 locale: Option<String>,
204 raw_parameters: BTreeMap<String, String>,
205 ) -> Result<AuthorizationGrant, Self::Error> {
206 let code_challenge = code
207 .as_ref()
208 .and_then(|c| c.pkce.as_ref())
209 .map(|p| &p.challenge);
210 let code_challenge_method = code
211 .as_ref()
212 .and_then(|c| c.pkce.as_ref())
213 .map(|p| p.challenge_method.to_string());
214 let code_str = code.as_ref().map(|c| &c.code);
215
216 let created_at = clock.now();
217 let id = Ulid::from_datetime_with_rng(created_at, rng);
218 tracing::Span::current().record("grant.id", tracing::field::display(id));
219
220 sqlx::query!(
221 r#"
222 INSERT INTO oauth2_authorization_grants (
223 oauth2_authorization_grant_id,
224 oauth2_client_id,
225 redirect_uri,
226 scope,
227 state,
228 nonce,
229 response_mode,
230 code_challenge,
231 code_challenge_method,
232 response_type_code,
233 response_type_id_token,
234 authorization_code,
235 login_hint,
236 locale,
237 raw_parameters,
238 created_at
239 )
240 VALUES
241 ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16)
242 "#,
243 Uuid::from(id),
244 Uuid::from(client.id),
245 redirect_uri.to_string(),
246 scope.to_string(),
247 state,
248 nonce,
249 response_mode.to_string(),
250 code_challenge,
251 code_challenge_method,
252 code.is_some(),
253 response_type_id_token,
254 code_str,
255 login_hint,
256 locale,
257 Json(&raw_parameters) as _,
258 created_at,
259 )
260 .traced()
261 .execute(&mut *self.conn)
262 .await?;
263
264 Ok(AuthorizationGrant {
265 id,
266 stage: AuthorizationGrantStage::Pending,
267 code,
268 redirect_uri,
269 client_id: client.id,
270 scope,
271 state,
272 nonce,
273 response_mode,
274 created_at,
275 response_type_id_token,
276 login_hint,
277 locale,
278 raw_parameters,
279 })
280 }
281
282 #[tracing::instrument(
283 name = "db.oauth2_authorization_grant.lookup",
284 skip_all,
285 fields(
286 db.query.text,
287 grant.id = %id,
288 ),
289 err,
290 )]
291 async fn lookup(&mut self, id: Ulid) -> Result<Option<AuthorizationGrant>, Self::Error> {
292 let res = sqlx::query_as!(
293 GrantLookup,
294 r#"
295 SELECT oauth2_authorization_grant_id
296 , created_at
297 , cancelled_at
298 , fulfilled_at
299 , exchanged_at
300 , scope
301 , state
302 , redirect_uri
303 , response_mode
304 , nonce
305 , oauth2_client_id
306 , authorization_code
307 , response_type_code
308 , response_type_id_token
309 , code_challenge
310 , code_challenge_method
311 , login_hint
312 , locale
313 , raw_parameters AS "raw_parameters: Json<BTreeMap<String, String>>"
314 , oauth2_session_id
315 FROM
316 oauth2_authorization_grants
317
318 WHERE oauth2_authorization_grant_id = $1
319 "#,
320 Uuid::from(id),
321 )
322 .traced()
323 .fetch_optional(&mut *self.conn)
324 .await?;
325
326 let Some(res) = res else { return Ok(None) };
327
328 Ok(Some(res.try_into()?))
329 }
330
331 #[tracing::instrument(
332 name = "db.oauth2_authorization_grant.find_by_code",
333 skip_all,
334 fields(
335 db.query.text,
336 ),
337 err,
338 )]
339 async fn find_by_code(
340 &mut self,
341 code: &str,
342 ) -> Result<Option<AuthorizationGrant>, Self::Error> {
343 let res = sqlx::query_as!(
344 GrantLookup,
345 r#"
346 SELECT oauth2_authorization_grant_id
347 , created_at
348 , cancelled_at
349 , fulfilled_at
350 , exchanged_at
351 , scope
352 , state
353 , redirect_uri
354 , response_mode
355 , nonce
356 , oauth2_client_id
357 , authorization_code
358 , response_type_code
359 , response_type_id_token
360 , code_challenge
361 , code_challenge_method
362 , login_hint
363 , locale
364 , raw_parameters AS "raw_parameters: Json<BTreeMap<String, String>>"
365 , oauth2_session_id
366 FROM
367 oauth2_authorization_grants
368
369 WHERE authorization_code = $1
370 "#,
371 code,
372 )
373 .traced()
374 .fetch_optional(&mut *self.conn)
375 .await?;
376
377 let Some(res) = res else { return Ok(None) };
378
379 Ok(Some(res.try_into()?))
380 }
381
382 #[tracing::instrument(
383 name = "db.oauth2_authorization_grant.fulfill",
384 skip_all,
385 fields(
386 db.query.text,
387 %grant.id,
388 client.id = %grant.client_id,
389 %session.id,
390 ),
391 err,
392 )]
393 async fn fulfill(
394 &mut self,
395 clock: &dyn Clock,
396 session: &Session,
397 grant: AuthorizationGrant,
398 ) -> Result<AuthorizationGrant, Self::Error> {
399 let fulfilled_at = clock.now();
400 let res = sqlx::query!(
401 r#"
402 UPDATE oauth2_authorization_grants
403 SET fulfilled_at = $2
404 , oauth2_session_id = $3
405 WHERE oauth2_authorization_grant_id = $1
406 "#,
407 Uuid::from(grant.id),
408 fulfilled_at,
409 Uuid::from(session.id),
410 )
411 .traced()
412 .execute(&mut *self.conn)
413 .await?;
414
415 DatabaseError::ensure_affected_rows(&res, 1)?;
416
417 let grant = grant
419 .fulfill(fulfilled_at, session)
420 .map_err(DatabaseError::to_invalid_operation)?;
421
422 Ok(grant)
423 }
424
425 #[tracing::instrument(
426 name = "db.oauth2_authorization_grant.exchange",
427 skip_all,
428 fields(
429 db.query.text,
430 %grant.id,
431 client.id = %grant.client_id,
432 ),
433 err,
434 )]
435 async fn exchange(
436 &mut self,
437 clock: &dyn Clock,
438 grant: AuthorizationGrant,
439 ) -> Result<AuthorizationGrant, Self::Error> {
440 let exchanged_at = clock.now();
441 let res = sqlx::query!(
442 r#"
443 UPDATE oauth2_authorization_grants
444 SET exchanged_at = $2
445 WHERE oauth2_authorization_grant_id = $1
446 "#,
447 Uuid::from(grant.id),
448 exchanged_at,
449 )
450 .traced()
451 .execute(&mut *self.conn)
452 .await?;
453
454 DatabaseError::ensure_affected_rows(&res, 1)?;
455
456 let grant = grant
457 .exchange(exchanged_at)
458 .map_err(DatabaseError::to_invalid_operation)?;
459
460 Ok(grant)
461 }
462
463 #[tracing::instrument(
464 name = "db.oauth2_authorization_grant.cleanup",
465 skip_all,
466 fields(
467 db.query.text,
468 since = since.map(tracing::field::display),
469 until = %until,
470 limit = limit,
471 ),
472 err,
473 )]
474 async fn cleanup(
475 &mut self,
476 since: Option<Ulid>,
477 until: Ulid,
478 limit: usize,
479 ) -> Result<(usize, Option<Ulid>), Self::Error> {
480 let res = sqlx::query_scalar!(
485 r#"
486 WITH to_delete AS (
487 SELECT oauth2_authorization_grant_id
488 FROM oauth2_authorization_grants
489 WHERE ($1::uuid IS NULL OR oauth2_authorization_grant_id > $1)
490 AND oauth2_authorization_grant_id <= $2
491 ORDER BY oauth2_authorization_grant_id
492 LIMIT $3
493 )
494 DELETE FROM oauth2_authorization_grants
495 USING to_delete
496 WHERE oauth2_authorization_grants.oauth2_authorization_grant_id = to_delete.oauth2_authorization_grant_id
497 RETURNING oauth2_authorization_grants.oauth2_authorization_grant_id
498 "#,
499 since.map(Uuid::from),
500 Uuid::from(until),
501 i64::try_from(limit).unwrap_or(i64::MAX)
502 )
503 .traced()
504 .fetch_all(&mut *self.conn)
505 .await?;
506
507 let count = res.len();
508 let max_id = res.into_iter().max();
509
510 Ok((count, max_id.map(Ulid::from)))
511 }
512}