1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
// Copyright 2024 New Vector Ltd.
// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only
// Please see LICENSE in the repository root for full details.

use std::marker::PhantomData;

use async_trait::async_trait;
use mas_data_model::{
    UpstreamOAuthProvider, UpstreamOAuthProviderClaimsImports, UpstreamOAuthProviderDiscoveryMode,
    UpstreamOAuthProviderPkceMode, UpstreamOAuthProviderResponseMode,
    UpstreamOAuthProviderTokenAuthMethod,
};
use mas_iana::jose::JsonWebSignatureAlg;
use oauth2_types::scope::Scope;
use rand_core::RngCore;
use ulid::Ulid;
use url::Url;

use crate::{pagination::Page, repository_impl, Clock, Pagination};

/// Parameters describing an upstream OAuth 2.0 provider
///
/// This is the `params` argument taken by
/// [`UpstreamOAuthProviderRepository::add`] and
/// [`UpstreamOAuthProviderRepository::upsert`], i.e. the values supplied when
/// inserting or updating a provider.
pub struct UpstreamOAuthProviderParams {
    /// The OIDC issuer of the provider
    pub issuer: String,

    /// A human-readable name for the provider, if any
    pub human_name: Option<String>,

    /// A brand identifier, e.g. "apple" or "google", if any
    pub brand_name: Option<String>,

    /// The scope to request during the authorization flow
    pub scope: Scope,

    /// The authentication method to use at the token endpoint
    pub token_endpoint_auth_method: UpstreamOAuthProviderTokenAuthMethod,

    /// The JWT signing algorithm to use when the `client_secret_jwt` or
    /// `private_key_jwt` authentication methods are used
    pub token_endpoint_signing_alg: Option<JsonWebSignatureAlg>,

    /// The client ID to use when authenticating to the upstream
    pub client_id: String,

    /// The client secret to use when authenticating to the upstream, in its
    /// encrypted form, if any
    pub encrypted_client_secret: Option<String>,

    /// How claims should be imported from the upstream provider
    pub claims_imports: UpstreamOAuthProviderClaimsImports,

    /// The URL to use as the authorization endpoint. If `None`, the URL will be
    /// discovered
    pub authorization_endpoint_override: Option<Url>,

    /// The URL to use as the token endpoint. If `None`, the URL will be
    /// discovered
    pub token_endpoint_override: Option<Url>,

    /// The URL to use when fetching JWKS. If `None`, the URL will be discovered
    pub jwks_uri_override: Option<Url>,

    /// How the provider metadata should be discovered
    pub discovery_mode: UpstreamOAuthProviderDiscoveryMode,

    /// How PKCE should be used with this provider
    pub pkce_mode: UpstreamOAuthProviderPkceMode,

    /// Which response mode to ask the upstream provider for
    pub response_mode: UpstreamOAuthProviderResponseMode,

    /// Additional parameters to include in the authorization request, as a
    /// list of string pairs (presumably `(name, value)` — confirm against the
    /// code building the authorization request)
    pub additional_authorization_parameters: Vec<(String, String)>,
}

/// Criteria used to narrow down a listing or a count of upstream OAuth 2.0
/// providers
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
pub struct UpstreamOAuthProviderFilter<'a> {
    /// Restrict the results to providers with this enabled state
    ///
    /// `None` leaves the enabled state unconstrained
    enabled: Option<bool>,

    _lifetime: PhantomData<&'a ()>,
}

impl UpstreamOAuthProviderFilter<'_> {
    /// Build an [`UpstreamOAuthProviderFilter`] with no criteria set
    ///
    /// This is the same value as the one produced by [`Default::default`].
    #[must_use]
    pub fn new() -> Self {
        Self {
            enabled: None,
            _lifetime: PhantomData,
        }
    }

    /// Restrict the filter to enabled providers
    ///
    /// Any previously set enabled/disabled criterion is replaced.
    #[must_use]
    pub const fn enabled_only(self) -> Self {
        Self {
            enabled: Some(true),
            ..self
        }
    }

    /// Restrict the filter to disabled providers
    ///
    /// Any previously set enabled/disabled criterion is replaced.
    #[must_use]
    pub const fn disabled_only(self) -> Self {
        Self {
            enabled: Some(false),
            ..self
        }
    }

    /// The enabled-state criterion of this filter
    ///
    /// `Some(true)` selects enabled providers, `Some(false)` selects disabled
    /// ones, and `None` means the criterion is not set
    #[must_use]
    pub const fn enabled(&self) -> Option<bool> {
        let Self { enabled, .. } = *self;
        enabled
    }
}

/// Operations on the [`UpstreamOAuthProvider`]s saved in the storage backend
/// are exposed through an [`UpstreamOAuthProviderRepository`]
#[async_trait]
pub trait UpstreamOAuthProviderRepository: Send + Sync {
    /// The error type returned by the repository
    type Error;

    /// Find an upstream OAuth provider from its ID
    ///
    /// Yields `None` when no provider was found
    ///
    /// # Parameters
    ///
    /// * `id`: The ID of the provider to look for
    ///
    /// # Errors
    ///
    /// Returns [`Self::Error`] if the underlying repository fails
    async fn lookup(&mut self, id: Ulid) -> Result<Option<UpstreamOAuthProvider>, Self::Error>;

    /// Create a new upstream OAuth provider
    ///
    /// Yields the provider which was just created
    ///
    /// # Parameters
    ///
    /// * `rng`: The random number generator to use
    /// * `clock`: The clock used to generate timestamps
    /// * `params`: The parameters describing the provider to create
    ///
    /// # Errors
    ///
    /// Returns [`Self::Error`] if the underlying repository fails
    async fn add(
        &mut self,
        rng: &mut (dyn RngCore + Send),
        clock: &dyn Clock,
        params: UpstreamOAuthProviderParams,
    ) -> Result<UpstreamOAuthProvider, Self::Error>;

    /// Remove an upstream OAuth provider
    ///
    /// Unless overridden, this forwards to [`Self::delete_by_id`] with the ID
    /// of the given provider
    ///
    /// # Parameters
    ///
    /// * `provider`: The provider to remove
    ///
    /// # Errors
    ///
    /// Returns [`Self::Error`] if the underlying repository fails
    async fn delete(&mut self, provider: UpstreamOAuthProvider) -> Result<(), Self::Error> {
        let provider_id = provider.id;
        self.delete_by_id(provider_id).await
    }

    /// Remove an upstream OAuth provider, designated by its ID
    ///
    /// # Parameters
    ///
    /// * `id`: The ID of the provider to remove
    ///
    /// # Errors
    ///
    /// Returns [`Self::Error`] if the underlying repository fails
    async fn delete_by_id(&mut self, id: Ulid) -> Result<(), Self::Error>;

    /// Insert an upstream OAuth provider, or update it
    ///
    /// Yields an [`UpstreamOAuthProvider`] on success
    ///
    /// # Parameters
    ///
    /// * `clock`: The clock used to generate timestamps
    /// * `id`: The ID of the provider to insert or update
    /// * `params`: The parameters describing the provider
    ///
    /// # Errors
    ///
    /// Returns [`Self::Error`] if the underlying repository fails
    async fn upsert(
        &mut self,
        clock: &dyn Clock,
        id: Ulid,
        params: UpstreamOAuthProviderParams,
    ) -> Result<UpstreamOAuthProvider, Self::Error>;

    /// Mark an upstream OAuth provider as disabled
    ///
    /// Yields the provider, once disabled
    ///
    /// # Parameters
    ///
    /// * `clock`: The clock used to generate timestamps
    /// * `provider`: The provider to mark as disabled
    ///
    /// # Errors
    ///
    /// Returns [`Self::Error`] if the underlying repository fails
    async fn disable(
        &mut self,
        clock: &dyn Clock,
        provider: UpstreamOAuthProvider,
    ) -> Result<UpstreamOAuthProvider, Self::Error>;

    /// Get a page of [`UpstreamOAuthProvider`] matching the given filter
    ///
    /// # Parameters
    ///
    /// * `filter`: The criteria the providers must match
    /// * `pagination`: The pagination parameters
    ///
    /// # Errors
    ///
    /// Returns [`Self::Error`] if the underlying repository fails
    async fn list(
        &mut self,
        filter: UpstreamOAuthProviderFilter<'_>,
        pagination: Pagination,
    ) -> Result<Page<UpstreamOAuthProvider>, Self::Error>;

    /// Count how many [`UpstreamOAuthProvider`] match the given filter
    ///
    /// # Parameters
    ///
    /// * `filter`: The criteria the providers must match
    ///
    /// # Errors
    ///
    /// Returns [`Self::Error`] if the underlying repository fails
    async fn count(
        &mut self,
        filter: UpstreamOAuthProviderFilter<'_>,
    ) -> Result<usize, Self::Error>;

    /// Fetch every enabled upstream OAuth provider at once
    ///
    /// # Errors
    ///
    /// Returns [`Self::Error`] if the underlying repository fails
    async fn all_enabled(&mut self) -> Result<Vec<UpstreamOAuthProvider>, Self::Error>;
}

// Invokes the crate's `repository_impl!` macro with the signature of every
// method of `UpstreamOAuthProviderRepository`, including the
// default-implemented `delete`.
//
// NOTE(review): the macro definition is not visible from this file;
// presumably it generates forwarding implementations of the trait for
// wrapper types — confirm against its definition.
//
// The parameter lists below carry no trailing comma, unlike the trait
// declaration; presumably the macro's matcher requires that — verify before
// reformatting. The signatures presumably have to stay in sync with the
// trait whenever a method is added, removed or changed.
repository_impl!(UpstreamOAuthProviderRepository:
    async fn lookup(&mut self, id: Ulid) -> Result<Option<UpstreamOAuthProvider>, Self::Error>;

    async fn add(
        &mut self,
        rng: &mut (dyn RngCore + Send),
        clock: &dyn Clock,
        params: UpstreamOAuthProviderParams
    ) -> Result<UpstreamOAuthProvider, Self::Error>;

    async fn upsert(
        &mut self,
        clock: &dyn Clock,
        id: Ulid,
        params: UpstreamOAuthProviderParams
    ) -> Result<UpstreamOAuthProvider, Self::Error>;

    async fn delete(&mut self, provider: UpstreamOAuthProvider) -> Result<(), Self::Error>;

    async fn delete_by_id(&mut self, id: Ulid) -> Result<(), Self::Error>;

    async fn disable(
        &mut self,
        clock: &dyn Clock,
        provider: UpstreamOAuthProvider
    ) -> Result<UpstreamOAuthProvider, Self::Error>;

    async fn list(
        &mut self,
        filter: UpstreamOAuthProviderFilter<'_>,
        pagination: Pagination
    ) -> Result<Page<UpstreamOAuthProvider>, Self::Error>;

    async fn count(
        &mut self,
        filter: UpstreamOAuthProviderFilter<'_>
    ) -> Result<usize, Self::Error>;

    async fn all_enabled(&mut self) -> Result<Vec<UpstreamOAuthProvider>, Self::Error>;
);