mas_storage/upstream_oauth2/provider.rs
1// Copyright 2024 New Vector Ltd.
2// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
7use std::marker::PhantomData;
8
9use async_trait::async_trait;
10use mas_data_model::{
11 UpstreamOAuthProvider, UpstreamOAuthProviderClaimsImports, UpstreamOAuthProviderDiscoveryMode,
12 UpstreamOAuthProviderPkceMode, UpstreamOAuthProviderResponseMode,
13 UpstreamOAuthProviderTokenAuthMethod,
14};
15use mas_iana::jose::JsonWebSignatureAlg;
16use oauth2_types::scope::Scope;
17use rand_core::RngCore;
18use ulid::Ulid;
19use url::Url;
20
21use crate::{Clock, Pagination, pagination::Page, repository_impl};
22
23/// Structure which holds parameters when inserting or updating an upstream
24/// OAuth 2.0 provider
25pub struct UpstreamOAuthProviderParams {
26 /// The OIDC issuer of the provider
27 pub issuer: Option<String>,
28
29 /// A human-readable name for the provider
30 pub human_name: Option<String>,
31
32 /// A brand identifier, e.g. "apple" or "google"
33 pub brand_name: Option<String>,
34
35 /// The scope to request during the authorization flow
36 pub scope: Scope,
37
38 /// The token endpoint authentication method
39 pub token_endpoint_auth_method: UpstreamOAuthProviderTokenAuthMethod,
40
41 /// The JWT signing algorithm to use when then `client_secret_jwt` or
42 /// `private_key_jwt` authentication methods are used
43 pub token_endpoint_signing_alg: Option<JsonWebSignatureAlg>,
44
45 /// Expected signature for the JWT payload returned by the token
46 /// authentication endpoint.
47 ///
48 /// Defaults to `RS256`.
49 pub id_token_signed_response_alg: JsonWebSignatureAlg,
50
51 /// Whether to fetch the user profile from the userinfo endpoint,
52 /// or to rely on the data returned in the `id_token` from the
53 /// `token_endpoint`.
54 pub fetch_userinfo: bool,
55
56 /// Expected signature for the JWT payload returned by the userinfo
57 /// endpoint.
58 ///
59 /// If not specified, the response is expected to be an unsigned JSON
60 /// payload. Defaults to `None`.
61 pub userinfo_signed_response_alg: Option<JsonWebSignatureAlg>,
62
63 /// The client ID to use when authenticating to the upstream
64 pub client_id: String,
65
66 /// The encrypted client secret to use when authenticating to the upstream
67 pub encrypted_client_secret: Option<String>,
68
69 /// How claims should be imported from the upstream provider
70 pub claims_imports: UpstreamOAuthProviderClaimsImports,
71
72 /// The URL to use as the authorization endpoint. If `None`, the URL will be
73 /// discovered
74 pub authorization_endpoint_override: Option<Url>,
75
76 /// The URL to use as the token endpoint. If `None`, the URL will be
77 /// discovered
78 pub token_endpoint_override: Option<Url>,
79
80 /// The URL to use as the userinfo endpoint. If `None`, the URL will be
81 /// discovered
82 pub userinfo_endpoint_override: Option<Url>,
83
84 /// The URL to use when fetching JWKS. If `None`, the URL will be discovered
85 pub jwks_uri_override: Option<Url>,
86
87 /// How the provider metadata should be discovered
88 pub discovery_mode: UpstreamOAuthProviderDiscoveryMode,
89
90 /// How should PKCE be used
91 pub pkce_mode: UpstreamOAuthProviderPkceMode,
92
93 /// What response mode it should ask
94 pub response_mode: Option<UpstreamOAuthProviderResponseMode>,
95
96 /// Additional parameters to include in the authorization request
97 pub additional_authorization_parameters: Vec<(String, String)>,
98
99 /// Whether to forward the login hint to the upstream provider.
100 pub forward_login_hint: bool,
101
102 /// The position of the provider in the UI
103 pub ui_order: i32,
104}
105
/// Criteria narrowing down which upstream OAuth 2.0 providers are returned
/// when listing or counting them
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]
pub struct UpstreamOAuthProviderFilter<'a> {
    /// Optional constraint on whether the provider is enabled
    ///
    /// `Some(true)` keeps only enabled providers, `Some(false)` keeps only
    /// disabled ones, and `None` applies no constraint, so every provider is
    /// returned
    enabled: Option<bool>,

    /// Binds the filter to the `'a` lifetime without holding borrowed data
    _lifetime: PhantomData<&'a ()>,
}
116
117impl UpstreamOAuthProviderFilter<'_> {
118 /// Create a new [`UpstreamOAuthProviderFilter`] with default values
119 #[must_use]
120 pub fn new() -> Self {
121 Self::default()
122 }
123
124 /// Return only enabled providers
125 #[must_use]
126 pub const fn enabled_only(mut self) -> Self {
127 self.enabled = Some(true);
128 self
129 }
130
131 /// Return only disabled providers
132 #[must_use]
133 pub const fn disabled_only(mut self) -> Self {
134 self.enabled = Some(false);
135 self
136 }
137
138 /// Get the enabled filter
139 ///
140 /// Returns `None` if the filter is not set
141 #[must_use]
142 pub const fn enabled(&self) -> Option<bool> {
143 self.enabled
144 }
145}
146
147/// An [`UpstreamOAuthProviderRepository`] helps interacting with
148/// [`UpstreamOAuthProvider`] saved in the storage backend
149#[async_trait]
150pub trait UpstreamOAuthProviderRepository: Send + Sync {
151 /// The error type returned by the repository
152 type Error;
153
154 /// Lookup an upstream OAuth provider by its ID
155 ///
156 /// Returns `None` if the provider was not found
157 ///
158 /// # Parameters
159 ///
160 /// * `id`: The ID of the provider to lookup
161 ///
162 /// # Errors
163 ///
164 /// Returns [`Self::Error`] if the underlying repository fails
165 async fn lookup(&mut self, id: Ulid) -> Result<Option<UpstreamOAuthProvider>, Self::Error>;
166
167 /// Add a new upstream OAuth provider
168 ///
169 /// Returns the newly created provider
170 ///
171 /// # Parameters
172 ///
173 /// * `rng`: A random number generator
174 /// * `clock`: The clock used to generate timestamps
175 /// * `params`: The parameters of the provider to add
176 ///
177 /// # Errors
178 ///
179 /// Returns [`Self::Error`] if the underlying repository fails
180 async fn add(
181 &mut self,
182 rng: &mut (dyn RngCore + Send),
183 clock: &dyn Clock,
184 params: UpstreamOAuthProviderParams,
185 ) -> Result<UpstreamOAuthProvider, Self::Error>;
186
187 /// Delete an upstream OAuth provider
188 ///
189 /// # Parameters
190 ///
191 /// * `provider`: The provider to delete
192 ///
193 /// # Errors
194 ///
195 /// Returns [`Self::Error`] if the underlying repository fails
196 async fn delete(&mut self, provider: UpstreamOAuthProvider) -> Result<(), Self::Error> {
197 self.delete_by_id(provider.id).await
198 }
199
200 /// Delete an upstream OAuth provider by its ID
201 ///
202 /// # Parameters
203 ///
204 /// * `id`: The ID of the provider to delete
205 ///
206 /// # Errors
207 ///
208 /// Returns [`Self::Error`] if the underlying repository fails
209 async fn delete_by_id(&mut self, id: Ulid) -> Result<(), Self::Error>;
210
211 /// Insert or update an upstream OAuth provider
212 ///
213 /// # Parameters
214 ///
215 /// * `clock`: The clock used to generate timestamps
216 /// * `id`: The ID of the provider to update
217 /// * `params`: The parameters of the provider to update
218 ///
219 /// # Errors
220 ///
221 /// Returns [`Self::Error`] if the underlying repository fails
222 async fn upsert(
223 &mut self,
224 clock: &dyn Clock,
225 id: Ulid,
226 params: UpstreamOAuthProviderParams,
227 ) -> Result<UpstreamOAuthProvider, Self::Error>;
228
229 /// Disable an upstream OAuth provider
230 ///
231 /// Returns the disabled provider
232 ///
233 /// # Parameters
234 ///
235 /// * `clock`: The clock used to generate timestamps
236 /// * `provider`: The provider to disable
237 ///
238 /// # Errors
239 ///
240 /// Returns [`Self::Error`] if the underlying repository fails
241 async fn disable(
242 &mut self,
243 clock: &dyn Clock,
244 provider: UpstreamOAuthProvider,
245 ) -> Result<UpstreamOAuthProvider, Self::Error>;
246
247 /// List [`UpstreamOAuthProvider`] with the given filter and pagination
248 ///
249 /// # Parameters
250 ///
251 /// * `filter`: The filter to apply
252 /// * `pagination`: The pagination parameters
253 ///
254 /// # Errors
255 ///
256 /// Returns [`Self::Error`] if the underlying repository fails
257 async fn list(
258 &mut self,
259 filter: UpstreamOAuthProviderFilter<'_>,
260 pagination: Pagination,
261 ) -> Result<Page<UpstreamOAuthProvider>, Self::Error>;
262
263 /// Count the number of [`UpstreamOAuthProvider`] with the given filter
264 ///
265 /// # Parameters
266 ///
267 /// * `filter`: The filter to apply
268 ///
269 /// # Errors
270 ///
271 /// Returns [`Self::Error`] if the underlying repository fails
272 async fn count(
273 &mut self,
274 filter: UpstreamOAuthProviderFilter<'_>,
275 ) -> Result<usize, Self::Error>;
276
277 /// Get all enabled upstream OAuth providers
278 ///
279 /// # Errors
280 ///
281 /// Returns [`Self::Error`] if the underlying repository fails
282 async fn all_enabled(&mut self) -> Result<Vec<UpstreamOAuthProvider>, Self::Error>;
283}
284
// Feed every method of `UpstreamOAuthProviderRepository` to the project's
// `repository_impl!` macro. The signatures below mirror the trait
// declaration above.
//
// NOTE(review): the macro definition is not visible here. Presumably it
// generates implementations of the trait for wrapper types that delegate
// each listed method to the inner repository; confirm against
// `repository_impl!`.
// The parameter lists deliberately end without a trailing comma. This is
// presumably what the macro's matcher expects, so keep it that way unless
// the macro changes.
// Only `//` comments are used in here: `///` doc comments would turn into
// attributes and become part of the macro input.
repository_impl!(UpstreamOAuthProviderRepository:
    async fn lookup(&mut self, id: Ulid) -> Result<Option<UpstreamOAuthProvider>, Self::Error>;

    async fn add(
        &mut self,
        rng: &mut (dyn RngCore + Send),
        clock: &dyn Clock,
        params: UpstreamOAuthProviderParams
    ) -> Result<UpstreamOAuthProvider, Self::Error>;

    async fn upsert(
        &mut self,
        clock: &dyn Clock,
        id: Ulid,
        params: UpstreamOAuthProviderParams
    ) -> Result<UpstreamOAuthProvider, Self::Error>;

    // `delete` has a default body in the trait but is still listed here.
    // Presumably this is so the generated code forwards to the inner type's
    // own `delete` rather than using the default. Verify before removing.
    async fn delete(&mut self, provider: UpstreamOAuthProvider) -> Result<(), Self::Error>;

    async fn delete_by_id(&mut self, id: Ulid) -> Result<(), Self::Error>;

    async fn disable(
        &mut self,
        clock: &dyn Clock,
        provider: UpstreamOAuthProvider
    ) -> Result<UpstreamOAuthProvider, Self::Error>;

    async fn list(
        &mut self,
        filter: UpstreamOAuthProviderFilter<'_>,
        pagination: Pagination
    ) -> Result<Page<UpstreamOAuthProvider>, Self::Error>;

    async fn count(
        &mut self,
        filter: UpstreamOAuthProviderFilter<'_>
    ) -> Result<usize, Self::Error>;

    async fn all_enabled(&mut self) -> Result<Vec<UpstreamOAuthProvider>, Self::Error>;
);