1use std::ops::{Deref, DerefMut};
8
9use async_trait::async_trait;
10use futures_util::{FutureExt, TryFutureExt, future::BoxFuture};
11use mas_storage::{
12 BoxRepository, BoxRepositoryFactory, MapErr, Repository, RepositoryAccess, RepositoryError,
13 RepositoryFactory, RepositoryTransaction,
14 app_session::AppSessionRepository,
15 compat::{
16 CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository,
17 CompatSsoLoginRepository,
18 },
19 oauth2::{
20 OAuth2AccessTokenRepository, OAuth2AuthorizationGrantRepository, OAuth2ClientRepository,
21 OAuth2DeviceCodeGrantRepository, OAuth2RefreshTokenRepository, OAuth2SessionRepository,
22 },
23 policy_data::PolicyDataRepository,
24 queue::{QueueJobRepository, QueueScheduleRepository, QueueWorkerRepository},
25 upstream_oauth2::{
26 UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository,
27 UpstreamOAuthSessionRepository,
28 },
29 user::{BrowserSessionRepository, UserEmailRepository, UserPasswordRepository, UserRepository},
30};
31use sqlx::{PgConnection, PgPool, Postgres, Transaction};
32use tracing::Instrument;
33
34use crate::{
35 DatabaseError,
36 app_session::PgAppSessionRepository,
37 compat::{
38 PgCompatAccessTokenRepository, PgCompatRefreshTokenRepository, PgCompatSessionRepository,
39 PgCompatSsoLoginRepository,
40 },
41 oauth2::{
42 PgOAuth2AccessTokenRepository, PgOAuth2AuthorizationGrantRepository,
43 PgOAuth2ClientRepository, PgOAuth2DeviceCodeGrantRepository,
44 PgOAuth2RefreshTokenRepository, PgOAuth2SessionRepository,
45 },
46 policy_data::PgPolicyDataRepository,
47 queue::{
48 job::PgQueueJobRepository, schedule::PgQueueScheduleRepository,
49 worker::PgQueueWorkerRepository,
50 },
51 telemetry::DB_CLIENT_CONNECTIONS_CREATE_TIME_HISTOGRAM,
52 upstream_oauth2::{
53 PgUpstreamOAuthLinkRepository, PgUpstreamOAuthProviderRepository,
54 PgUpstreamOAuthSessionRepository,
55 },
56 user::{
57 PgBrowserSessionRepository, PgUserEmailRepository, PgUserPasswordRepository,
58 PgUserRecoveryRepository, PgUserRegistrationRepository, PgUserRepository,
59 PgUserTermsRepository,
60 },
61};
62
63#[derive(Clone)]
66pub struct PgRepositoryFactory {
67 pool: PgPool,
68}
69
70impl PgRepositoryFactory {
71 #[must_use]
73 pub fn new(pool: PgPool) -> Self {
74 Self { pool }
75 }
76
77 #[must_use]
79 pub fn boxed(self) -> BoxRepositoryFactory {
80 Box::new(self)
81 }
82
83 #[must_use]
85 pub fn pool(&self) -> PgPool {
86 self.pool.clone()
87 }
88}
89
90#[async_trait]
91impl RepositoryFactory for PgRepositoryFactory {
92 async fn create(&self) -> Result<BoxRepository, RepositoryError> {
93 let start = std::time::Instant::now();
94 let repo = PgRepository::from_pool(&self.pool)
95 .await
96 .map_err(RepositoryError::from_error)?
97 .boxed();
98
99 let duration = start.elapsed();
101 let duration_ms = duration.as_millis().try_into().unwrap_or(u64::MAX);
102 DB_CLIENT_CONNECTIONS_CREATE_TIME_HISTOGRAM.record(duration_ms, &[]);
103
104 Ok(repo)
105 }
106}
107
108pub struct PgRepository<C = Transaction<'static, Postgres>> {
111 conn: C,
112}
113
114impl PgRepository {
115 pub async fn from_pool(pool: &PgPool) -> Result<Self, DatabaseError> {
122 let txn = pool.begin().await?;
123 Ok(Self::from_conn(txn))
124 }
125
126 pub fn boxed(self) -> BoxRepository {
128 Box::new(MapErr::new(self, RepositoryError::from_error))
129 }
130}
131
132impl<C> PgRepository<C> {
133 pub fn from_conn(conn: C) -> Self {
136 PgRepository { conn }
137 }
138
139 pub fn into_inner(self) -> C {
141 self.conn
142 }
143}
144
145impl<C> AsRef<C> for PgRepository<C> {
146 fn as_ref(&self) -> &C {
147 &self.conn
148 }
149}
150
151impl<C> AsMut<C> for PgRepository<C> {
152 fn as_mut(&mut self) -> &mut C {
153 &mut self.conn
154 }
155}
156
157impl<C> Deref for PgRepository<C> {
158 type Target = C;
159
160 fn deref(&self) -> &Self::Target {
161 &self.conn
162 }
163}
164
165impl<C> DerefMut for PgRepository<C> {
166 fn deref_mut(&mut self) -> &mut Self::Target {
167 &mut self.conn
168 }
169}
170
171impl Repository<DatabaseError> for PgRepository {}
172
173impl RepositoryTransaction for PgRepository {
174 type Error = DatabaseError;
175
176 fn save(self: Box<Self>) -> BoxFuture<'static, Result<(), Self::Error>> {
177 let span = tracing::info_span!("db.save");
178 self.conn
179 .commit()
180 .map_err(DatabaseError::from)
181 .instrument(span)
182 .boxed()
183 }
184
185 fn cancel(self: Box<Self>) -> BoxFuture<'static, Result<(), Self::Error>> {
186 let span = tracing::info_span!("db.cancel");
187 self.conn
188 .rollback()
189 .map_err(DatabaseError::from)
190 .instrument(span)
191 .boxed()
192 }
193}
194
195impl<C> RepositoryAccess for PgRepository<C>
196where
197 C: AsMut<PgConnection> + Send,
198{
199 type Error = DatabaseError;
200
201 fn upstream_oauth_link<'c>(
202 &'c mut self,
203 ) -> Box<dyn UpstreamOAuthLinkRepository<Error = Self::Error> + 'c> {
204 Box::new(PgUpstreamOAuthLinkRepository::new(self.conn.as_mut()))
205 }
206
207 fn upstream_oauth_provider<'c>(
208 &'c mut self,
209 ) -> Box<dyn UpstreamOAuthProviderRepository<Error = Self::Error> + 'c> {
210 Box::new(PgUpstreamOAuthProviderRepository::new(self.conn.as_mut()))
211 }
212
213 fn upstream_oauth_session<'c>(
214 &'c mut self,
215 ) -> Box<dyn UpstreamOAuthSessionRepository<Error = Self::Error> + 'c> {
216 Box::new(PgUpstreamOAuthSessionRepository::new(self.conn.as_mut()))
217 }
218
219 fn user<'c>(&'c mut self) -> Box<dyn UserRepository<Error = Self::Error> + 'c> {
220 Box::new(PgUserRepository::new(self.conn.as_mut()))
221 }
222
223 fn user_email<'c>(&'c mut self) -> Box<dyn UserEmailRepository<Error = Self::Error> + 'c> {
224 Box::new(PgUserEmailRepository::new(self.conn.as_mut()))
225 }
226
227 fn user_password<'c>(
228 &'c mut self,
229 ) -> Box<dyn UserPasswordRepository<Error = Self::Error> + 'c> {
230 Box::new(PgUserPasswordRepository::new(self.conn.as_mut()))
231 }
232
233 fn user_recovery<'c>(
234 &'c mut self,
235 ) -> Box<dyn mas_storage::user::UserRecoveryRepository<Error = Self::Error> + 'c> {
236 Box::new(PgUserRecoveryRepository::new(self.conn.as_mut()))
237 }
238
239 fn user_terms<'c>(
240 &'c mut self,
241 ) -> Box<dyn mas_storage::user::UserTermsRepository<Error = Self::Error> + 'c> {
242 Box::new(PgUserTermsRepository::new(self.conn.as_mut()))
243 }
244
245 fn user_registration<'c>(
246 &'c mut self,
247 ) -> Box<dyn mas_storage::user::UserRegistrationRepository<Error = Self::Error> + 'c> {
248 Box::new(PgUserRegistrationRepository::new(self.conn.as_mut()))
249 }
250
251 fn browser_session<'c>(
252 &'c mut self,
253 ) -> Box<dyn BrowserSessionRepository<Error = Self::Error> + 'c> {
254 Box::new(PgBrowserSessionRepository::new(self.conn.as_mut()))
255 }
256
257 fn app_session<'c>(&'c mut self) -> Box<dyn AppSessionRepository<Error = Self::Error> + 'c> {
258 Box::new(PgAppSessionRepository::new(self.conn.as_mut()))
259 }
260
261 fn oauth2_client<'c>(
262 &'c mut self,
263 ) -> Box<dyn OAuth2ClientRepository<Error = Self::Error> + 'c> {
264 Box::new(PgOAuth2ClientRepository::new(self.conn.as_mut()))
265 }
266
267 fn oauth2_authorization_grant<'c>(
268 &'c mut self,
269 ) -> Box<dyn OAuth2AuthorizationGrantRepository<Error = Self::Error> + 'c> {
270 Box::new(PgOAuth2AuthorizationGrantRepository::new(
271 self.conn.as_mut(),
272 ))
273 }
274
275 fn oauth2_session<'c>(
276 &'c mut self,
277 ) -> Box<dyn OAuth2SessionRepository<Error = Self::Error> + 'c> {
278 Box::new(PgOAuth2SessionRepository::new(self.conn.as_mut()))
279 }
280
281 fn oauth2_access_token<'c>(
282 &'c mut self,
283 ) -> Box<dyn OAuth2AccessTokenRepository<Error = Self::Error> + 'c> {
284 Box::new(PgOAuth2AccessTokenRepository::new(self.conn.as_mut()))
285 }
286
287 fn oauth2_refresh_token<'c>(
288 &'c mut self,
289 ) -> Box<dyn OAuth2RefreshTokenRepository<Error = Self::Error> + 'c> {
290 Box::new(PgOAuth2RefreshTokenRepository::new(self.conn.as_mut()))
291 }
292
293 fn oauth2_device_code_grant<'c>(
294 &'c mut self,
295 ) -> Box<dyn OAuth2DeviceCodeGrantRepository<Error = Self::Error> + 'c> {
296 Box::new(PgOAuth2DeviceCodeGrantRepository::new(self.conn.as_mut()))
297 }
298
299 fn compat_session<'c>(
300 &'c mut self,
301 ) -> Box<dyn CompatSessionRepository<Error = Self::Error> + 'c> {
302 Box::new(PgCompatSessionRepository::new(self.conn.as_mut()))
303 }
304
305 fn compat_sso_login<'c>(
306 &'c mut self,
307 ) -> Box<dyn CompatSsoLoginRepository<Error = Self::Error> + 'c> {
308 Box::new(PgCompatSsoLoginRepository::new(self.conn.as_mut()))
309 }
310
311 fn compat_access_token<'c>(
312 &'c mut self,
313 ) -> Box<dyn CompatAccessTokenRepository<Error = Self::Error> + 'c> {
314 Box::new(PgCompatAccessTokenRepository::new(self.conn.as_mut()))
315 }
316
317 fn compat_refresh_token<'c>(
318 &'c mut self,
319 ) -> Box<dyn CompatRefreshTokenRepository<Error = Self::Error> + 'c> {
320 Box::new(PgCompatRefreshTokenRepository::new(self.conn.as_mut()))
321 }
322
323 fn queue_worker<'c>(&'c mut self) -> Box<dyn QueueWorkerRepository<Error = Self::Error> + 'c> {
324 Box::new(PgQueueWorkerRepository::new(self.conn.as_mut()))
325 }
326
327 fn queue_job<'c>(&'c mut self) -> Box<dyn QueueJobRepository<Error = Self::Error> + 'c> {
328 Box::new(PgQueueJobRepository::new(self.conn.as_mut()))
329 }
330
331 fn queue_schedule<'c>(
332 &'c mut self,
333 ) -> Box<dyn QueueScheduleRepository<Error = Self::Error> + 'c> {
334 Box::new(PgQueueScheduleRepository::new(self.conn.as_mut()))
335 }
336
337 fn policy_data<'c>(&'c mut self) -> Box<dyn PolicyDataRepository<Error = Self::Error> + 'c> {
338 Box::new(PgPolicyDataRepository::new(self.conn.as_mut()))
339 }
340}