1use async_trait::async_trait;
8use chrono::{DateTime, Utc};
9use mas_data_model::{
10 Clock, CompatAccessToken, CompatRefreshToken, CompatRefreshTokenState, CompatSession,
11 UlidExt as _,
12};
13use mas_storage::compat::CompatRefreshTokenRepository;
14use rand::RngCore;
15use sqlx::PgConnection;
16use ulid::Ulid;
17use uuid::Uuid;
18
19use crate::{DatabaseError, tracing::ExecuteExt};
20
21pub struct PgCompatRefreshTokenRepository<'c> {
24 conn: &'c mut PgConnection,
25}
26
27impl<'c> PgCompatRefreshTokenRepository<'c> {
28 pub fn new(conn: &'c mut PgConnection) -> Self {
31 Self { conn }
32 }
33}
34
35struct CompatRefreshTokenLookup {
36 compat_refresh_token_id: Uuid,
37 refresh_token: String,
38 created_at: DateTime<Utc>,
39 consumed_at: Option<DateTime<Utc>>,
40 compat_access_token_id: Uuid,
41 compat_session_id: Uuid,
42}
43
44impl From<CompatRefreshTokenLookup> for CompatRefreshToken {
45 fn from(value: CompatRefreshTokenLookup) -> Self {
46 let state = match value.consumed_at {
47 Some(consumed_at) => CompatRefreshTokenState::Consumed { consumed_at },
48 None => CompatRefreshTokenState::Valid,
49 };
50
51 Self {
52 id: value.compat_refresh_token_id.into(),
53 state,
54 session_id: value.compat_session_id.into(),
55 token: value.refresh_token,
56 created_at: value.created_at,
57 access_token_id: value.compat_access_token_id.into(),
58 }
59 }
60}
61
62#[async_trait]
63impl CompatRefreshTokenRepository for PgCompatRefreshTokenRepository<'_> {
64 type Error = DatabaseError;
65
66 #[tracing::instrument(
67 name = "db.compat_refresh_token.lookup",
68 skip_all,
69 fields(
70 db.query.text,
71 compat_refresh_token.id = %id,
72 ),
73 err,
74 )]
75 async fn lookup(&mut self, id: Ulid) -> Result<Option<CompatRefreshToken>, Self::Error> {
76 let res = sqlx::query_as!(
77 CompatRefreshTokenLookup,
78 r#"
79 SELECT compat_refresh_token_id
80 , refresh_token
81 , created_at
82 , consumed_at
83 , compat_session_id
84 , compat_access_token_id
85
86 FROM compat_refresh_tokens
87
88 WHERE compat_refresh_token_id = $1
89 "#,
90 Uuid::from(id),
91 )
92 .traced()
93 .fetch_optional(&mut *self.conn)
94 .await?;
95
96 let Some(res) = res else { return Ok(None) };
97
98 Ok(Some(res.into()))
99 }
100
101 #[tracing::instrument(
102 name = "db.compat_refresh_token.find_by_token",
103 skip_all,
104 fields(
105 db.query.text,
106 ),
107 err,
108 )]
109 async fn find_by_token(
110 &mut self,
111 refresh_token: &str,
112 ) -> Result<Option<CompatRefreshToken>, Self::Error> {
113 let res = sqlx::query_as!(
114 CompatRefreshTokenLookup,
115 r#"
116 SELECT compat_refresh_token_id
117 , refresh_token
118 , created_at
119 , consumed_at
120 , compat_session_id
121 , compat_access_token_id
122
123 FROM compat_refresh_tokens
124
125 WHERE refresh_token = $1
126 "#,
127 refresh_token,
128 )
129 .traced()
130 .fetch_optional(&mut *self.conn)
131 .await?;
132
133 let Some(res) = res else { return Ok(None) };
134
135 Ok(Some(res.into()))
136 }
137
138 #[tracing::instrument(
139 name = "db.compat_refresh_token.add",
140 skip_all,
141 fields(
142 db.query.text,
143 compat_refresh_token.id,
144 %compat_session.id,
145 user.id = %compat_session.user_id,
146 ),
147 err,
148 )]
149 async fn add(
150 &mut self,
151 rng: &mut (dyn RngCore + Send),
152 clock: &dyn Clock,
153 compat_session: &CompatSession,
154 compat_access_token: &CompatAccessToken,
155 token: String,
156 ) -> Result<CompatRefreshToken, Self::Error> {
157 let created_at = clock.now();
158 let id = Ulid::from_datetime_with_rng(created_at, rng);
159 tracing::Span::current().record("compat_refresh_token.id", tracing::field::display(id));
160
161 sqlx::query!(
162 r#"
163 INSERT INTO compat_refresh_tokens
164 (compat_refresh_token_id, compat_session_id,
165 compat_access_token_id, refresh_token, created_at)
166 VALUES ($1, $2, $3, $4, $5)
167 "#,
168 Uuid::from(id),
169 Uuid::from(compat_session.id),
170 Uuid::from(compat_access_token.id),
171 token,
172 created_at,
173 )
174 .traced()
175 .execute(&mut *self.conn)
176 .await?;
177
178 Ok(CompatRefreshToken {
179 id,
180 state: CompatRefreshTokenState::default(),
181 session_id: compat_session.id,
182 access_token_id: compat_access_token.id,
183 token,
184 created_at,
185 })
186 }
187
188 #[tracing::instrument(
189 name = "db.compat_refresh_token.consume_and_replace",
190 skip_all,
191 fields(
192 db.query.text,
193 %compat_refresh_token.id,
194 %successor_compat_refresh_token.id,
195 compat_session.id = %compat_refresh_token.session_id,
196 ),
197 err,
198 )]
199 async fn consume_and_replace(
200 &mut self,
201 clock: &dyn Clock,
202 compat_refresh_token: CompatRefreshToken,
203 successor_compat_refresh_token: &CompatRefreshToken,
204 ) -> Result<CompatRefreshToken, Self::Error> {
205 if compat_refresh_token.session_id != successor_compat_refresh_token.session_id {
206 return Err(DatabaseError::invalid_operation());
207 }
208
209 let consumed_at = clock.now();
210 let res = sqlx::query!(
211 r#"
212 UPDATE compat_refresh_tokens
213 SET consumed_at = $2
214 WHERE compat_session_id = $1
215 AND consumed_at IS NULL
216 AND compat_refresh_token_id <> $3
217 "#,
218 Uuid::from(compat_refresh_token.session_id),
219 consumed_at,
220 Uuid::from(successor_compat_refresh_token.id),
221 )
222 .traced()
223 .execute(&mut *self.conn)
224 .await?;
225
226 if res.rows_affected() == 0 {
230 return Err(DatabaseError::RowsAffected {
231 expected: 1,
232 actual: 0,
233 });
234 }
235
236 let compat_refresh_token = compat_refresh_token
237 .consume(consumed_at)
238 .map_err(DatabaseError::to_invalid_operation)?;
239
240 Ok(compat_refresh_token)
241 }
242}