mas_storage/
repository.rs

1// Copyright 2024 New Vector Ltd.
2// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
7use async_trait::async_trait;
8use futures_util::future::BoxFuture;
9use thiserror::Error;
10
11use crate::{
12    app_session::AppSessionRepository,
13    compat::{
14        CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository,
15        CompatSsoLoginRepository,
16    },
17    oauth2::{
18        OAuth2AccessTokenRepository, OAuth2AuthorizationGrantRepository, OAuth2ClientRepository,
19        OAuth2DeviceCodeGrantRepository, OAuth2RefreshTokenRepository, OAuth2SessionRepository,
20    },
21    policy_data::PolicyDataRepository,
22    queue::{QueueJobRepository, QueueScheduleRepository, QueueWorkerRepository},
23    upstream_oauth2::{
24        UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository,
25        UpstreamOAuthSessionRepository,
26    },
27    user::{
28        BrowserSessionRepository, UserEmailRepository, UserPasswordRepository,
29        UserRecoveryRepository, UserRegistrationRepository, UserRepository, UserTermsRepository,
30    },
31};
32
33/// A [`RepositoryFactory`] is a factory that can create a [`BoxRepository`]
34// XXX(quenting): this could be generic over the repository type, but it's annoying to make it
35// dyn-safe
36#[async_trait]
37pub trait RepositoryFactory {
38    /// Create a new [`BoxRepository`]
39    async fn create(&self) -> Result<BoxRepository, RepositoryError>;
40}
41
42/// A type-erased [`RepositoryFactory`]
43pub type BoxRepositoryFactory = Box<dyn RepositoryFactory + Send + Sync + 'static>;
44
45/// A [`Repository`] helps interacting with the underlying storage backend.
46pub trait Repository<E>:
47    RepositoryAccess<Error = E> + RepositoryTransaction<Error = E> + Send
48where
49    E: std::error::Error + Send + Sync + 'static,
50{
51}
52
53/// An opaque, type-erased error
54#[derive(Debug, Error)]
55#[error(transparent)]
56pub struct RepositoryError {
57    source: Box<dyn std::error::Error + Send + Sync + 'static>,
58}
59
60impl RepositoryError {
61    /// Construct a [`RepositoryError`] from any error kind
62    pub fn from_error<E>(value: E) -> Self
63    where
64        E: std::error::Error + Send + Sync + 'static,
65    {
66        Self {
67            source: Box::new(value),
68        }
69    }
70}
71
72/// A type-erased [`Repository`]
73pub type BoxRepository = Box<dyn Repository<RepositoryError> + Send + Sync + 'static>;
74
75/// A [`RepositoryTransaction`] can be saved or cancelled, after a series
76/// of operations.
77pub trait RepositoryTransaction {
78    /// The error type used by the [`Self::save`] and [`Self::cancel`] functions
79    type Error;
80
81    /// Commit the transaction
82    ///
83    /// # Errors
84    ///
85    /// Returns an error if the underlying storage backend failed to commit the
86    /// transaction.
87    fn save(self: Box<Self>) -> BoxFuture<'static, Result<(), Self::Error>>;
88
89    /// Rollback the transaction
90    ///
91    /// # Errors
92    ///
93    /// Returns an error if the underlying storage backend failed to rollback
94    /// the transaction.
95    fn cancel(self: Box<Self>) -> BoxFuture<'static, Result<(), Self::Error>>;
96}
97
98/// Access the various repositories the backend implements.
99///
100/// All the methods return a boxed trait object, which can be used to access a
101/// particular repository. The lifetime of the returned object is bound to the
102/// lifetime of the whole repository, so that only one mutable reference to the
103/// repository is used at a time.
104///
105/// When adding a new repository, you should add a new method to this trait, and
106/// update the implementations for [`crate::MapErr`] and [`Box<R>`] below.
107///
108/// Note: this used to have generic associated types to avoid boxing all the
109/// repository traits, but that was removed because it made almost impossible to
110/// box the trait object. This might be a shortcoming of the initial
111/// implementation of generic associated types, and might be fixed in the
112/// future.
113pub trait RepositoryAccess: Send {
114    /// The backend-specific error type used by each repository.
115    type Error: std::error::Error + Send + Sync + 'static;
116
117    /// Get an [`UpstreamOAuthLinkRepository`]
118    fn upstream_oauth_link<'c>(
119        &'c mut self,
120    ) -> Box<dyn UpstreamOAuthLinkRepository<Error = Self::Error> + 'c>;
121
122    /// Get an [`UpstreamOAuthProviderRepository`]
123    fn upstream_oauth_provider<'c>(
124        &'c mut self,
125    ) -> Box<dyn UpstreamOAuthProviderRepository<Error = Self::Error> + 'c>;
126
127    /// Get an [`UpstreamOAuthSessionRepository`]
128    fn upstream_oauth_session<'c>(
129        &'c mut self,
130    ) -> Box<dyn UpstreamOAuthSessionRepository<Error = Self::Error> + 'c>;
131
132    /// Get an [`UserRepository`]
133    fn user<'c>(&'c mut self) -> Box<dyn UserRepository<Error = Self::Error> + 'c>;
134
135    /// Get an [`UserEmailRepository`]
136    fn user_email<'c>(&'c mut self) -> Box<dyn UserEmailRepository<Error = Self::Error> + 'c>;
137
138    /// Get an [`UserPasswordRepository`]
139    fn user_password<'c>(&'c mut self)
140    -> Box<dyn UserPasswordRepository<Error = Self::Error> + 'c>;
141
142    /// Get an [`UserRecoveryRepository`]
143    fn user_recovery<'c>(&'c mut self)
144    -> Box<dyn UserRecoveryRepository<Error = Self::Error> + 'c>;
145
146    /// Get an [`UserRegistrationRepository`]
147    fn user_registration<'c>(
148        &'c mut self,
149    ) -> Box<dyn UserRegistrationRepository<Error = Self::Error> + 'c>;
150
151    /// Get an [`UserTermsRepository`]
152    fn user_terms<'c>(&'c mut self) -> Box<dyn UserTermsRepository<Error = Self::Error> + 'c>;
153
154    /// Get a [`BrowserSessionRepository`]
155    fn browser_session<'c>(
156        &'c mut self,
157    ) -> Box<dyn BrowserSessionRepository<Error = Self::Error> + 'c>;
158
159    /// Get a [`AppSessionRepository`]
160    fn app_session<'c>(&'c mut self) -> Box<dyn AppSessionRepository<Error = Self::Error> + 'c>;
161
162    /// Get an [`OAuth2ClientRepository`]
163    fn oauth2_client<'c>(&'c mut self)
164    -> Box<dyn OAuth2ClientRepository<Error = Self::Error> + 'c>;
165
166    /// Get an [`OAuth2AuthorizationGrantRepository`]
167    fn oauth2_authorization_grant<'c>(
168        &'c mut self,
169    ) -> Box<dyn OAuth2AuthorizationGrantRepository<Error = Self::Error> + 'c>;
170
171    /// Get an [`OAuth2SessionRepository`]
172    fn oauth2_session<'c>(
173        &'c mut self,
174    ) -> Box<dyn OAuth2SessionRepository<Error = Self::Error> + 'c>;
175
176    /// Get an [`OAuth2AccessTokenRepository`]
177    fn oauth2_access_token<'c>(
178        &'c mut self,
179    ) -> Box<dyn OAuth2AccessTokenRepository<Error = Self::Error> + 'c>;
180
181    /// Get an [`OAuth2RefreshTokenRepository`]
182    fn oauth2_refresh_token<'c>(
183        &'c mut self,
184    ) -> Box<dyn OAuth2RefreshTokenRepository<Error = Self::Error> + 'c>;
185
186    /// Get an [`OAuth2DeviceCodeGrantRepository`]
187    fn oauth2_device_code_grant<'c>(
188        &'c mut self,
189    ) -> Box<dyn OAuth2DeviceCodeGrantRepository<Error = Self::Error> + 'c>;
190
191    /// Get a [`CompatSessionRepository`]
192    fn compat_session<'c>(
193        &'c mut self,
194    ) -> Box<dyn CompatSessionRepository<Error = Self::Error> + 'c>;
195
196    /// Get a [`CompatSsoLoginRepository`]
197    fn compat_sso_login<'c>(
198        &'c mut self,
199    ) -> Box<dyn CompatSsoLoginRepository<Error = Self::Error> + 'c>;
200
201    /// Get a [`CompatAccessTokenRepository`]
202    fn compat_access_token<'c>(
203        &'c mut self,
204    ) -> Box<dyn CompatAccessTokenRepository<Error = Self::Error> + 'c>;
205
206    /// Get a [`CompatRefreshTokenRepository`]
207    fn compat_refresh_token<'c>(
208        &'c mut self,
209    ) -> Box<dyn CompatRefreshTokenRepository<Error = Self::Error> + 'c>;
210
211    /// Get a [`QueueWorkerRepository`]
212    fn queue_worker<'c>(&'c mut self) -> Box<dyn QueueWorkerRepository<Error = Self::Error> + 'c>;
213
214    /// Get a [`QueueJobRepository`]
215    fn queue_job<'c>(&'c mut self) -> Box<dyn QueueJobRepository<Error = Self::Error> + 'c>;
216
217    /// Get a [`QueueScheduleRepository`]
218    fn queue_schedule<'c>(
219        &'c mut self,
220    ) -> Box<dyn QueueScheduleRepository<Error = Self::Error> + 'c>;
221
222    /// Get a [`PolicyDataRepository`]
223    fn policy_data<'c>(&'c mut self) -> Box<dyn PolicyDataRepository<Error = Self::Error> + 'c>;
224}
225
/// Implementations of the [`RepositoryAccess`], [`RepositoryTransaction`] and
/// [`Repository`] traits for the [`crate::MapErr`] wrapper and [`Box<R>`]
228mod impls {
229    use futures_util::{FutureExt, TryFutureExt, future::BoxFuture};
230
231    use super::RepositoryAccess;
232    use crate::{
233        MapErr, Repository, RepositoryTransaction,
234        app_session::AppSessionRepository,
235        compat::{
236            CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository,
237            CompatSsoLoginRepository,
238        },
239        oauth2::{
240            OAuth2AccessTokenRepository, OAuth2AuthorizationGrantRepository,
241            OAuth2ClientRepository, OAuth2DeviceCodeGrantRepository, OAuth2RefreshTokenRepository,
242            OAuth2SessionRepository,
243        },
244        policy_data::PolicyDataRepository,
245        queue::{QueueJobRepository, QueueScheduleRepository, QueueWorkerRepository},
246        upstream_oauth2::{
247            UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository,
248            UpstreamOAuthSessionRepository,
249        },
250        user::{
251            BrowserSessionRepository, UserEmailRepository, UserPasswordRepository,
252            UserRegistrationRepository, UserRepository, UserTermsRepository,
253        },
254    };
255
256    // --- Repository ---
257    impl<R, F, E1, E2> Repository<E2> for MapErr<R, F>
258    where
259        R: Repository<E1> + RepositoryAccess<Error = E1> + RepositoryTransaction<Error = E1>,
260        F: FnMut(E1) -> E2 + Send + Sync + 'static,
261        E1: std::error::Error + Send + Sync + 'static,
262        E2: std::error::Error + Send + Sync + 'static,
263    {
264    }
265
266    // --- RepositoryTransaction --
267    impl<R, F, E> RepositoryTransaction for MapErr<R, F>
268    where
269        R: RepositoryTransaction,
270        R::Error: 'static,
271        F: FnMut(R::Error) -> E + Send + Sync + 'static,
272        E: std::error::Error,
273    {
274        type Error = E;
275
276        fn save(self: Box<Self>) -> BoxFuture<'static, Result<(), Self::Error>> {
277            Box::new(self.inner).save().map_err(self.mapper).boxed()
278        }
279
280        fn cancel(self: Box<Self>) -> BoxFuture<'static, Result<(), Self::Error>> {
281            Box::new(self.inner).cancel().map_err(self.mapper).boxed()
282        }
283    }
284
285    // --- RepositoryAccess --
286    impl<R, F, E> RepositoryAccess for MapErr<R, F>
287    where
288        R: RepositoryAccess,
289        R::Error: 'static,
290        F: FnMut(R::Error) -> E + Send + Sync + 'static,
291        E: std::error::Error + Send + Sync + 'static,
292    {
293        type Error = E;
294
295        fn upstream_oauth_link<'c>(
296            &'c mut self,
297        ) -> Box<dyn UpstreamOAuthLinkRepository<Error = Self::Error> + 'c> {
298            Box::new(MapErr::new(
299                self.inner.upstream_oauth_link(),
300                &mut self.mapper,
301            ))
302        }
303
304        fn upstream_oauth_provider<'c>(
305            &'c mut self,
306        ) -> Box<dyn UpstreamOAuthProviderRepository<Error = Self::Error> + 'c> {
307            Box::new(MapErr::new(
308                self.inner.upstream_oauth_provider(),
309                &mut self.mapper,
310            ))
311        }
312
313        fn upstream_oauth_session<'c>(
314            &'c mut self,
315        ) -> Box<dyn UpstreamOAuthSessionRepository<Error = Self::Error> + 'c> {
316            Box::new(MapErr::new(
317                self.inner.upstream_oauth_session(),
318                &mut self.mapper,
319            ))
320        }
321
322        fn user<'c>(&'c mut self) -> Box<dyn UserRepository<Error = Self::Error> + 'c> {
323            Box::new(MapErr::new(self.inner.user(), &mut self.mapper))
324        }
325
326        fn user_email<'c>(&'c mut self) -> Box<dyn UserEmailRepository<Error = Self::Error> + 'c> {
327            Box::new(MapErr::new(self.inner.user_email(), &mut self.mapper))
328        }
329
330        fn user_password<'c>(
331            &'c mut self,
332        ) -> Box<dyn UserPasswordRepository<Error = Self::Error> + 'c> {
333            Box::new(MapErr::new(self.inner.user_password(), &mut self.mapper))
334        }
335
336        fn user_recovery<'c>(
337            &'c mut self,
338        ) -> Box<dyn crate::user::UserRecoveryRepository<Error = Self::Error> + 'c> {
339            Box::new(MapErr::new(self.inner.user_recovery(), &mut self.mapper))
340        }
341
342        fn user_registration<'c>(
343            &'c mut self,
344        ) -> Box<dyn UserRegistrationRepository<Error = Self::Error> + 'c> {
345            Box::new(MapErr::new(
346                self.inner.user_registration(),
347                &mut self.mapper,
348            ))
349        }
350
351        fn user_terms<'c>(&'c mut self) -> Box<dyn UserTermsRepository<Error = Self::Error> + 'c> {
352            Box::new(MapErr::new(self.inner.user_terms(), &mut self.mapper))
353        }
354
355        fn browser_session<'c>(
356            &'c mut self,
357        ) -> Box<dyn BrowserSessionRepository<Error = Self::Error> + 'c> {
358            Box::new(MapErr::new(self.inner.browser_session(), &mut self.mapper))
359        }
360
361        fn app_session<'c>(
362            &'c mut self,
363        ) -> Box<dyn AppSessionRepository<Error = Self::Error> + 'c> {
364            Box::new(MapErr::new(self.inner.app_session(), &mut self.mapper))
365        }
366
367        fn oauth2_client<'c>(
368            &'c mut self,
369        ) -> Box<dyn OAuth2ClientRepository<Error = Self::Error> + 'c> {
370            Box::new(MapErr::new(self.inner.oauth2_client(), &mut self.mapper))
371        }
372
373        fn oauth2_authorization_grant<'c>(
374            &'c mut self,
375        ) -> Box<dyn OAuth2AuthorizationGrantRepository<Error = Self::Error> + 'c> {
376            Box::new(MapErr::new(
377                self.inner.oauth2_authorization_grant(),
378                &mut self.mapper,
379            ))
380        }
381
382        fn oauth2_session<'c>(
383            &'c mut self,
384        ) -> Box<dyn OAuth2SessionRepository<Error = Self::Error> + 'c> {
385            Box::new(MapErr::new(self.inner.oauth2_session(), &mut self.mapper))
386        }
387
388        fn oauth2_access_token<'c>(
389            &'c mut self,
390        ) -> Box<dyn OAuth2AccessTokenRepository<Error = Self::Error> + 'c> {
391            Box::new(MapErr::new(
392                self.inner.oauth2_access_token(),
393                &mut self.mapper,
394            ))
395        }
396
397        fn oauth2_refresh_token<'c>(
398            &'c mut self,
399        ) -> Box<dyn OAuth2RefreshTokenRepository<Error = Self::Error> + 'c> {
400            Box::new(MapErr::new(
401                self.inner.oauth2_refresh_token(),
402                &mut self.mapper,
403            ))
404        }
405
406        fn oauth2_device_code_grant<'c>(
407            &'c mut self,
408        ) -> Box<dyn OAuth2DeviceCodeGrantRepository<Error = Self::Error> + 'c> {
409            Box::new(MapErr::new(
410                self.inner.oauth2_device_code_grant(),
411                &mut self.mapper,
412            ))
413        }
414
415        fn compat_session<'c>(
416            &'c mut self,
417        ) -> Box<dyn CompatSessionRepository<Error = Self::Error> + 'c> {
418            Box::new(MapErr::new(self.inner.compat_session(), &mut self.mapper))
419        }
420
421        fn compat_sso_login<'c>(
422            &'c mut self,
423        ) -> Box<dyn CompatSsoLoginRepository<Error = Self::Error> + 'c> {
424            Box::new(MapErr::new(self.inner.compat_sso_login(), &mut self.mapper))
425        }
426
427        fn compat_access_token<'c>(
428            &'c mut self,
429        ) -> Box<dyn CompatAccessTokenRepository<Error = Self::Error> + 'c> {
430            Box::new(MapErr::new(
431                self.inner.compat_access_token(),
432                &mut self.mapper,
433            ))
434        }
435
436        fn compat_refresh_token<'c>(
437            &'c mut self,
438        ) -> Box<dyn CompatRefreshTokenRepository<Error = Self::Error> + 'c> {
439            Box::new(MapErr::new(
440                self.inner.compat_refresh_token(),
441                &mut self.mapper,
442            ))
443        }
444
445        fn queue_worker<'c>(
446            &'c mut self,
447        ) -> Box<dyn QueueWorkerRepository<Error = Self::Error> + 'c> {
448            Box::new(MapErr::new(self.inner.queue_worker(), &mut self.mapper))
449        }
450
451        fn queue_job<'c>(&'c mut self) -> Box<dyn QueueJobRepository<Error = Self::Error> + 'c> {
452            Box::new(MapErr::new(self.inner.queue_job(), &mut self.mapper))
453        }
454
455        fn queue_schedule<'c>(
456            &'c mut self,
457        ) -> Box<dyn QueueScheduleRepository<Error = Self::Error> + 'c> {
458            Box::new(MapErr::new(self.inner.queue_schedule(), &mut self.mapper))
459        }
460
461        fn policy_data<'c>(
462            &'c mut self,
463        ) -> Box<dyn PolicyDataRepository<Error = Self::Error> + 'c> {
464            Box::new(MapErr::new(self.inner.policy_data(), &mut self.mapper))
465        }
466    }
467
468    impl<R: RepositoryAccess + ?Sized> RepositoryAccess for Box<R> {
469        type Error = R::Error;
470
471        fn upstream_oauth_link<'c>(
472            &'c mut self,
473        ) -> Box<dyn UpstreamOAuthLinkRepository<Error = Self::Error> + 'c> {
474            (**self).upstream_oauth_link()
475        }
476
477        fn upstream_oauth_provider<'c>(
478            &'c mut self,
479        ) -> Box<dyn UpstreamOAuthProviderRepository<Error = Self::Error> + 'c> {
480            (**self).upstream_oauth_provider()
481        }
482
483        fn upstream_oauth_session<'c>(
484            &'c mut self,
485        ) -> Box<dyn UpstreamOAuthSessionRepository<Error = Self::Error> + 'c> {
486            (**self).upstream_oauth_session()
487        }
488
489        fn user<'c>(&'c mut self) -> Box<dyn UserRepository<Error = Self::Error> + 'c> {
490            (**self).user()
491        }
492
493        fn user_email<'c>(&'c mut self) -> Box<dyn UserEmailRepository<Error = Self::Error> + 'c> {
494            (**self).user_email()
495        }
496
497        fn user_password<'c>(
498            &'c mut self,
499        ) -> Box<dyn UserPasswordRepository<Error = Self::Error> + 'c> {
500            (**self).user_password()
501        }
502
503        fn user_recovery<'c>(
504            &'c mut self,
505        ) -> Box<dyn crate::user::UserRecoveryRepository<Error = Self::Error> + 'c> {
506            (**self).user_recovery()
507        }
508
509        fn user_registration<'c>(
510            &'c mut self,
511        ) -> Box<dyn UserRegistrationRepository<Error = Self::Error> + 'c> {
512            (**self).user_registration()
513        }
514
515        fn user_terms<'c>(&'c mut self) -> Box<dyn UserTermsRepository<Error = Self::Error> + 'c> {
516            (**self).user_terms()
517        }
518
519        fn browser_session<'c>(
520            &'c mut self,
521        ) -> Box<dyn BrowserSessionRepository<Error = Self::Error> + 'c> {
522            (**self).browser_session()
523        }
524
525        fn app_session<'c>(
526            &'c mut self,
527        ) -> Box<dyn AppSessionRepository<Error = Self::Error> + 'c> {
528            (**self).app_session()
529        }
530
531        fn oauth2_client<'c>(
532            &'c mut self,
533        ) -> Box<dyn OAuth2ClientRepository<Error = Self::Error> + 'c> {
534            (**self).oauth2_client()
535        }
536
537        fn oauth2_authorization_grant<'c>(
538            &'c mut self,
539        ) -> Box<dyn OAuth2AuthorizationGrantRepository<Error = Self::Error> + 'c> {
540            (**self).oauth2_authorization_grant()
541        }
542
543        fn oauth2_session<'c>(
544            &'c mut self,
545        ) -> Box<dyn OAuth2SessionRepository<Error = Self::Error> + 'c> {
546            (**self).oauth2_session()
547        }
548
549        fn oauth2_access_token<'c>(
550            &'c mut self,
551        ) -> Box<dyn OAuth2AccessTokenRepository<Error = Self::Error> + 'c> {
552            (**self).oauth2_access_token()
553        }
554
555        fn oauth2_refresh_token<'c>(
556            &'c mut self,
557        ) -> Box<dyn OAuth2RefreshTokenRepository<Error = Self::Error> + 'c> {
558            (**self).oauth2_refresh_token()
559        }
560
561        fn oauth2_device_code_grant<'c>(
562            &'c mut self,
563        ) -> Box<dyn OAuth2DeviceCodeGrantRepository<Error = Self::Error> + 'c> {
564            (**self).oauth2_device_code_grant()
565        }
566
567        fn compat_session<'c>(
568            &'c mut self,
569        ) -> Box<dyn CompatSessionRepository<Error = Self::Error> + 'c> {
570            (**self).compat_session()
571        }
572
573        fn compat_sso_login<'c>(
574            &'c mut self,
575        ) -> Box<dyn CompatSsoLoginRepository<Error = Self::Error> + 'c> {
576            (**self).compat_sso_login()
577        }
578
579        fn compat_access_token<'c>(
580            &'c mut self,
581        ) -> Box<dyn CompatAccessTokenRepository<Error = Self::Error> + 'c> {
582            (**self).compat_access_token()
583        }
584
585        fn compat_refresh_token<'c>(
586            &'c mut self,
587        ) -> Box<dyn CompatRefreshTokenRepository<Error = Self::Error> + 'c> {
588            (**self).compat_refresh_token()
589        }
590
591        fn queue_worker<'c>(
592            &'c mut self,
593        ) -> Box<dyn QueueWorkerRepository<Error = Self::Error> + 'c> {
594            (**self).queue_worker()
595        }
596
597        fn queue_job<'c>(&'c mut self) -> Box<dyn QueueJobRepository<Error = Self::Error> + 'c> {
598            (**self).queue_job()
599        }
600
601        fn queue_schedule<'c>(
602            &'c mut self,
603        ) -> Box<dyn QueueScheduleRepository<Error = Self::Error> + 'c> {
604            (**self).queue_schedule()
605        }
606
607        fn policy_data<'c>(
608            &'c mut self,
609        ) -> Box<dyn PolicyDataRepository<Error = Self::Error> + 'c> {
610            (**self).policy_data()
611        }
612    }
613}