mas_storage_pg/
repository.rs

1// Copyright 2024 New Vector Ltd.
2// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
7use std::ops::{Deref, DerefMut};
8
9use async_trait::async_trait;
10use futures_util::{FutureExt, TryFutureExt, future::BoxFuture};
11use mas_storage::{
12    BoxRepository, BoxRepositoryFactory, MapErr, Repository, RepositoryAccess, RepositoryError,
13    RepositoryFactory, RepositoryTransaction,
14    app_session::AppSessionRepository,
15    compat::{
16        CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository,
17        CompatSsoLoginRepository,
18    },
19    oauth2::{
20        OAuth2AccessTokenRepository, OAuth2AuthorizationGrantRepository, OAuth2ClientRepository,
21        OAuth2DeviceCodeGrantRepository, OAuth2RefreshTokenRepository, OAuth2SessionRepository,
22    },
23    policy_data::PolicyDataRepository,
24    queue::{QueueJobRepository, QueueScheduleRepository, QueueWorkerRepository},
25    upstream_oauth2::{
26        UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository,
27        UpstreamOAuthSessionRepository,
28    },
29    user::{BrowserSessionRepository, UserEmailRepository, UserPasswordRepository, UserRepository},
30};
31use sqlx::{PgConnection, PgPool, Postgres, Transaction};
32use tracing::Instrument;
33
34use crate::{
35    DatabaseError,
36    app_session::PgAppSessionRepository,
37    compat::{
38        PgCompatAccessTokenRepository, PgCompatRefreshTokenRepository, PgCompatSessionRepository,
39        PgCompatSsoLoginRepository,
40    },
41    oauth2::{
42        PgOAuth2AccessTokenRepository, PgOAuth2AuthorizationGrantRepository,
43        PgOAuth2ClientRepository, PgOAuth2DeviceCodeGrantRepository,
44        PgOAuth2RefreshTokenRepository, PgOAuth2SessionRepository,
45    },
46    policy_data::PgPolicyDataRepository,
47    queue::{
48        job::PgQueueJobRepository, schedule::PgQueueScheduleRepository,
49        worker::PgQueueWorkerRepository,
50    },
51    telemetry::DB_CLIENT_CONNECTIONS_CREATE_TIME_HISTOGRAM,
52    upstream_oauth2::{
53        PgUpstreamOAuthLinkRepository, PgUpstreamOAuthProviderRepository,
54        PgUpstreamOAuthSessionRepository,
55    },
56    user::{
57        PgBrowserSessionRepository, PgUserEmailRepository, PgUserPasswordRepository,
58        PgUserRecoveryRepository, PgUserRegistrationRepository, PgUserRepository,
59        PgUserTermsRepository,
60    },
61};
62
63/// An implementation of the [`RepositoryFactory`] trait backed by a PostgreSQL
64/// connection pool.
65#[derive(Clone)]
66pub struct PgRepositoryFactory {
67    pool: PgPool,
68}
69
70impl PgRepositoryFactory {
71    /// Create a new [`PgRepositoryFactory`] from a PostgreSQL connection pool.
72    #[must_use]
73    pub fn new(pool: PgPool) -> Self {
74        Self { pool }
75    }
76
77    /// Box the factory
78    #[must_use]
79    pub fn boxed(self) -> BoxRepositoryFactory {
80        Box::new(self)
81    }
82
83    /// Get the underlying PostgreSQL connection pool
84    #[must_use]
85    pub fn pool(&self) -> PgPool {
86        self.pool.clone()
87    }
88}
89
90#[async_trait]
91impl RepositoryFactory for PgRepositoryFactory {
92    async fn create(&self) -> Result<BoxRepository, RepositoryError> {
93        let start = std::time::Instant::now();
94        let repo = PgRepository::from_pool(&self.pool)
95            .await
96            .map_err(RepositoryError::from_error)?
97            .boxed();
98
99        // Measure the time it took to create the connection
100        let duration = start.elapsed();
101        let duration_ms = duration.as_millis().try_into().unwrap_or(u64::MAX);
102        DB_CLIENT_CONNECTIONS_CREATE_TIME_HISTOGRAM.record(duration_ms, &[]);
103
104        Ok(repo)
105    }
106}
107
108/// An implementation of the [`Repository`] trait backed by a PostgreSQL
109/// transaction.
110pub struct PgRepository<C = Transaction<'static, Postgres>> {
111    conn: C,
112}
113
114impl PgRepository {
115    /// Create a new [`PgRepository`] from a PostgreSQL connection pool,
116    /// starting a transaction.
117    ///
118    /// # Errors
119    ///
120    /// Returns a [`DatabaseError`] if the transaction could not be started.
121    pub async fn from_pool(pool: &PgPool) -> Result<Self, DatabaseError> {
122        let txn = pool.begin().await?;
123        Ok(Self::from_conn(txn))
124    }
125
126    /// Transform the repository into a type-erased [`BoxRepository`]
127    pub fn boxed(self) -> BoxRepository {
128        Box::new(MapErr::new(self, RepositoryError::from_error))
129    }
130}
131
132impl<C> PgRepository<C> {
133    /// Create a new [`PgRepository`] from an existing PostgreSQL connection
134    /// with a transaction
135    pub fn from_conn(conn: C) -> Self {
136        PgRepository { conn }
137    }
138
139    /// Consume this [`PgRepository`], returning the underlying connection.
140    pub fn into_inner(self) -> C {
141        self.conn
142    }
143}
144
145impl<C> AsRef<C> for PgRepository<C> {
146    fn as_ref(&self) -> &C {
147        &self.conn
148    }
149}
150
151impl<C> AsMut<C> for PgRepository<C> {
152    fn as_mut(&mut self) -> &mut C {
153        &mut self.conn
154    }
155}
156
157impl<C> Deref for PgRepository<C> {
158    type Target = C;
159
160    fn deref(&self) -> &Self::Target {
161        &self.conn
162    }
163}
164
165impl<C> DerefMut for PgRepository<C> {
166    fn deref_mut(&mut self) -> &mut Self::Target {
167        &mut self.conn
168    }
169}
170
171impl Repository<DatabaseError> for PgRepository {}
172
173impl RepositoryTransaction for PgRepository {
174    type Error = DatabaseError;
175
176    fn save(self: Box<Self>) -> BoxFuture<'static, Result<(), Self::Error>> {
177        let span = tracing::info_span!("db.save");
178        self.conn
179            .commit()
180            .map_err(DatabaseError::from)
181            .instrument(span)
182            .boxed()
183    }
184
185    fn cancel(self: Box<Self>) -> BoxFuture<'static, Result<(), Self::Error>> {
186        let span = tracing::info_span!("db.cancel");
187        self.conn
188            .rollback()
189            .map_err(DatabaseError::from)
190            .instrument(span)
191            .boxed()
192    }
193}
194
195impl<C> RepositoryAccess for PgRepository<C>
196where
197    C: AsMut<PgConnection> + Send,
198{
199    type Error = DatabaseError;
200
201    fn upstream_oauth_link<'c>(
202        &'c mut self,
203    ) -> Box<dyn UpstreamOAuthLinkRepository<Error = Self::Error> + 'c> {
204        Box::new(PgUpstreamOAuthLinkRepository::new(self.conn.as_mut()))
205    }
206
207    fn upstream_oauth_provider<'c>(
208        &'c mut self,
209    ) -> Box<dyn UpstreamOAuthProviderRepository<Error = Self::Error> + 'c> {
210        Box::new(PgUpstreamOAuthProviderRepository::new(self.conn.as_mut()))
211    }
212
213    fn upstream_oauth_session<'c>(
214        &'c mut self,
215    ) -> Box<dyn UpstreamOAuthSessionRepository<Error = Self::Error> + 'c> {
216        Box::new(PgUpstreamOAuthSessionRepository::new(self.conn.as_mut()))
217    }
218
219    fn user<'c>(&'c mut self) -> Box<dyn UserRepository<Error = Self::Error> + 'c> {
220        Box::new(PgUserRepository::new(self.conn.as_mut()))
221    }
222
223    fn user_email<'c>(&'c mut self) -> Box<dyn UserEmailRepository<Error = Self::Error> + 'c> {
224        Box::new(PgUserEmailRepository::new(self.conn.as_mut()))
225    }
226
227    fn user_password<'c>(
228        &'c mut self,
229    ) -> Box<dyn UserPasswordRepository<Error = Self::Error> + 'c> {
230        Box::new(PgUserPasswordRepository::new(self.conn.as_mut()))
231    }
232
233    fn user_recovery<'c>(
234        &'c mut self,
235    ) -> Box<dyn mas_storage::user::UserRecoveryRepository<Error = Self::Error> + 'c> {
236        Box::new(PgUserRecoveryRepository::new(self.conn.as_mut()))
237    }
238
239    fn user_terms<'c>(
240        &'c mut self,
241    ) -> Box<dyn mas_storage::user::UserTermsRepository<Error = Self::Error> + 'c> {
242        Box::new(PgUserTermsRepository::new(self.conn.as_mut()))
243    }
244
245    fn user_registration<'c>(
246        &'c mut self,
247    ) -> Box<dyn mas_storage::user::UserRegistrationRepository<Error = Self::Error> + 'c> {
248        Box::new(PgUserRegistrationRepository::new(self.conn.as_mut()))
249    }
250
251    fn browser_session<'c>(
252        &'c mut self,
253    ) -> Box<dyn BrowserSessionRepository<Error = Self::Error> + 'c> {
254        Box::new(PgBrowserSessionRepository::new(self.conn.as_mut()))
255    }
256
257    fn app_session<'c>(&'c mut self) -> Box<dyn AppSessionRepository<Error = Self::Error> + 'c> {
258        Box::new(PgAppSessionRepository::new(self.conn.as_mut()))
259    }
260
261    fn oauth2_client<'c>(
262        &'c mut self,
263    ) -> Box<dyn OAuth2ClientRepository<Error = Self::Error> + 'c> {
264        Box::new(PgOAuth2ClientRepository::new(self.conn.as_mut()))
265    }
266
267    fn oauth2_authorization_grant<'c>(
268        &'c mut self,
269    ) -> Box<dyn OAuth2AuthorizationGrantRepository<Error = Self::Error> + 'c> {
270        Box::new(PgOAuth2AuthorizationGrantRepository::new(
271            self.conn.as_mut(),
272        ))
273    }
274
275    fn oauth2_session<'c>(
276        &'c mut self,
277    ) -> Box<dyn OAuth2SessionRepository<Error = Self::Error> + 'c> {
278        Box::new(PgOAuth2SessionRepository::new(self.conn.as_mut()))
279    }
280
281    fn oauth2_access_token<'c>(
282        &'c mut self,
283    ) -> Box<dyn OAuth2AccessTokenRepository<Error = Self::Error> + 'c> {
284        Box::new(PgOAuth2AccessTokenRepository::new(self.conn.as_mut()))
285    }
286
287    fn oauth2_refresh_token<'c>(
288        &'c mut self,
289    ) -> Box<dyn OAuth2RefreshTokenRepository<Error = Self::Error> + 'c> {
290        Box::new(PgOAuth2RefreshTokenRepository::new(self.conn.as_mut()))
291    }
292
293    fn oauth2_device_code_grant<'c>(
294        &'c mut self,
295    ) -> Box<dyn OAuth2DeviceCodeGrantRepository<Error = Self::Error> + 'c> {
296        Box::new(PgOAuth2DeviceCodeGrantRepository::new(self.conn.as_mut()))
297    }
298
299    fn compat_session<'c>(
300        &'c mut self,
301    ) -> Box<dyn CompatSessionRepository<Error = Self::Error> + 'c> {
302        Box::new(PgCompatSessionRepository::new(self.conn.as_mut()))
303    }
304
305    fn compat_sso_login<'c>(
306        &'c mut self,
307    ) -> Box<dyn CompatSsoLoginRepository<Error = Self::Error> + 'c> {
308        Box::new(PgCompatSsoLoginRepository::new(self.conn.as_mut()))
309    }
310
311    fn compat_access_token<'c>(
312        &'c mut self,
313    ) -> Box<dyn CompatAccessTokenRepository<Error = Self::Error> + 'c> {
314        Box::new(PgCompatAccessTokenRepository::new(self.conn.as_mut()))
315    }
316
317    fn compat_refresh_token<'c>(
318        &'c mut self,
319    ) -> Box<dyn CompatRefreshTokenRepository<Error = Self::Error> + 'c> {
320        Box::new(PgCompatRefreshTokenRepository::new(self.conn.as_mut()))
321    }
322
323    fn queue_worker<'c>(&'c mut self) -> Box<dyn QueueWorkerRepository<Error = Self::Error> + 'c> {
324        Box::new(PgQueueWorkerRepository::new(self.conn.as_mut()))
325    }
326
327    fn queue_job<'c>(&'c mut self) -> Box<dyn QueueJobRepository<Error = Self::Error> + 'c> {
328        Box::new(PgQueueJobRepository::new(self.conn.as_mut()))
329    }
330
331    fn queue_schedule<'c>(
332        &'c mut self,
333    ) -> Box<dyn QueueScheduleRepository<Error = Self::Error> + 'c> {
334        Box::new(PgQueueScheduleRepository::new(self.conn.as_mut()))
335    }
336
337    fn policy_data<'c>(&'c mut self) -> Box<dyn PolicyDataRepository<Error = Self::Error> + 'c> {
338        Box::new(PgPolicyDataRepository::new(self.conn.as_mut()))
339    }
340}