// Copyright 2024 New Vector Ltd.
// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only
// Please see LICENSE in the repository root for full details.
use std::marker::PhantomData;
use async_trait::async_trait;
use mas_data_model::{
UpstreamOAuthProvider, UpstreamOAuthProviderClaimsImports, UpstreamOAuthProviderDiscoveryMode,
UpstreamOAuthProviderPkceMode,
};
use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod};
use oauth2_types::scope::Scope;
use rand_core::RngCore;
use ulid::Ulid;
use url::Url;
use crate::{pagination::Page, repository_impl, Clock, Pagination};
/// Set of parameters supplied when an upstream OAuth 2.0 provider is inserted
/// or updated
pub struct UpstreamOAuthProviderParams {
    /// OIDC issuer of the provider
    pub issuer: String,

    /// Human-readable name of the provider, if any
    pub human_name: Option<String>,

    /// Brand identifier of the provider, e.g. "apple" or "google"
    pub brand_name: Option<String>,

    /// Scope requested during the authorization flow
    pub scope: Scope,

    /// Authentication method used at the token endpoint
    pub token_endpoint_auth_method: OAuthClientAuthenticationMethod,

    /// JWT signing algorithm to use when either the `client_secret_jwt` or
    /// the `private_key_jwt` authentication method is used
    pub token_endpoint_signing_alg: Option<JsonWebSignatureAlg>,

    /// Client ID used when authenticating to the upstream
    pub client_id: String,

    /// Encrypted client secret used when authenticating to the upstream
    pub encrypted_client_secret: Option<String>,

    /// Describes how claims are imported from the upstream provider
    pub claims_imports: UpstreamOAuthProviderClaimsImports,

    /// Authorization endpoint URL override. When `None`, the URL is
    /// discovered
    pub authorization_endpoint_override: Option<Url>,

    /// Token endpoint URL override. When `None`, the URL is discovered
    pub token_endpoint_override: Option<Url>,

    /// JWKS URL override. When `None`, the URL is discovered
    pub jwks_uri_override: Option<Url>,

    /// Describes how the provider metadata is discovered
    pub discovery_mode: UpstreamOAuthProviderDiscoveryMode,

    /// Describes how PKCE is used
    pub pkce_mode: UpstreamOAuthProviderPkceMode,

    /// Extra `(name, value)` parameters to include in the authorization
    /// request
    pub additional_authorization_parameters: Vec<(String, String)>,
}
/// Criteria used to narrow down a listing of upstream OAuth 2.0 providers
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
pub struct UpstreamOAuthProviderFilter<'a> {
    /// Restricts the result on the enabled state of the provider
    ///
    /// `None` means that no restriction applies and all providers are
    /// returned
    enabled: Option<bool>,

    /// Zero-sized marker which keeps the `'a` lifetime parameter in use
    _lifetime: PhantomData<&'a ()>,
}

impl<'a> UpstreamOAuthProviderFilter<'a> {
    /// Build an [`UpstreamOAuthProviderFilter`] holding the default values,
    /// i.e. with no criteria set
    #[must_use]
    pub fn new() -> Self {
        <Self as Default>::default()
    }

    /// Restrict the filter to providers which are enabled
    #[must_use]
    pub const fn enabled_only(self) -> Self {
        Self {
            enabled: Some(true),
            ..self
        }
    }

    /// Restrict the filter to providers which are disabled
    #[must_use]
    pub const fn disabled_only(self) -> Self {
        Self {
            enabled: Some(false),
            ..self
        }
    }

    /// The enabled criterion currently held by the filter
    ///
    /// `None` is returned when the criterion was never set
    #[must_use]
    pub const fn enabled(&self) -> Option<bool> {
        self.enabled
    }
}
/// Access to the [`UpstreamOAuthProvider`] saved in the storage backend
///
/// Every operation is asynchronous and reports failures through the
/// associated [`Self::Error`] type.
#[async_trait]
pub trait UpstreamOAuthProviderRepository: Send + Sync {
    /// Error type which the repository operations return
    type Error;

    /// Find an upstream OAuth provider from its ID
    ///
    /// Yields `None` when no provider was found
    ///
    /// # Parameters
    ///
    /// * `id`: ID of the provider to look for
    ///
    /// # Errors
    ///
    /// Returns [`Self::Error`] if the underlying repository fails
    async fn lookup(&mut self, id: Ulid) -> Result<Option<UpstreamOAuthProvider>, Self::Error>;

    /// Create a new upstream OAuth provider
    ///
    /// Yields the provider which was just created
    ///
    /// # Parameters
    ///
    /// * `rng`: Random number generator
    /// * `clock`: Clock from which timestamps are generated
    /// * `params`: Parameters of the provider being added
    ///
    /// # Errors
    ///
    /// Returns [`Self::Error`] if the underlying repository fails
    async fn add(
        &mut self,
        rng: &mut (dyn RngCore + Send),
        clock: &dyn Clock,
        params: UpstreamOAuthProviderParams,
    ) -> Result<UpstreamOAuthProvider, Self::Error>;

    /// Remove an upstream OAuth provider
    ///
    /// The provided implementation forwards the ID of the given provider to
    /// [`Self::delete_by_id`]
    ///
    /// # Parameters
    ///
    /// * `provider`: Provider to remove
    ///
    /// # Errors
    ///
    /// Returns [`Self::Error`] if the underlying repository fails
    async fn delete(&mut self, provider: UpstreamOAuthProvider) -> Result<(), Self::Error> {
        let id = provider.id;
        self.delete_by_id(id).await
    }

    /// Remove an upstream OAuth provider, identified by its ID
    ///
    /// # Parameters
    ///
    /// * `id`: ID of the provider to remove
    ///
    /// # Errors
    ///
    /// Returns [`Self::Error`] if the underlying repository fails
    async fn delete_by_id(&mut self, id: Ulid) -> Result<(), Self::Error>;

    /// Create an upstream OAuth provider, or update it
    ///
    /// # Parameters
    ///
    /// * `clock`: Clock from which timestamps are generated
    /// * `id`: ID of the provider being updated
    /// * `params`: Parameters of the provider being updated
    ///
    /// # Errors
    ///
    /// Returns [`Self::Error`] if the underlying repository fails
    async fn upsert(
        &mut self,
        clock: &dyn Clock,
        id: Ulid,
        params: UpstreamOAuthProviderParams,
    ) -> Result<UpstreamOAuthProvider, Self::Error>;

    /// Mark an upstream OAuth provider as disabled
    ///
    /// Yields the provider once disabled
    ///
    /// # Parameters
    ///
    /// * `clock`: Clock from which timestamps are generated
    /// * `provider`: Provider to disable
    ///
    /// # Errors
    ///
    /// Returns [`Self::Error`] if the underlying repository fails
    async fn disable(
        &mut self,
        clock: &dyn Clock,
        provider: UpstreamOAuthProvider,
    ) -> Result<UpstreamOAuthProvider, Self::Error>;

    /// Fetch a page of [`UpstreamOAuthProvider`] matching the given filter
    ///
    /// # Parameters
    ///
    /// * `filter`: Filter applied to the listing
    /// * `pagination`: Pagination parameters
    ///
    /// # Errors
    ///
    /// Returns [`Self::Error`] if the underlying repository fails
    async fn list(
        &mut self,
        filter: UpstreamOAuthProviderFilter<'_>,
        pagination: Pagination,
    ) -> Result<Page<UpstreamOAuthProvider>, Self::Error>;

    /// Compute how many [`UpstreamOAuthProvider`] match the given filter
    ///
    /// # Parameters
    ///
    /// * `filter`: Filter applied to the count
    ///
    /// # Errors
    ///
    /// Returns [`Self::Error`] if the underlying repository fails
    async fn count(
        &mut self,
        filter: UpstreamOAuthProviderFilter<'_>,
    ) -> Result<usize, Self::Error>;

    /// Fetch every upstream OAuth provider which is enabled
    ///
    /// # Errors
    ///
    /// Returns [`Self::Error`] if the underlying repository fails
    async fn all_enabled(&mut self) -> Result<Vec<UpstreamOAuthProvider>, Self::Error>;
}
// Invokes the crate-level `repository_impl!` macro for
// `UpstreamOAuthProviderRepository`, handing it the signature of every method
// declared on the trait (including `delete`, which has a default body there).
//
// NOTE(review): the macro definition is not visible from this file; it
// presumably generates forwarding implementations of the trait for wrapper
// types — confirm against the macro before relying on that.
//
// The signatures below mirror the trait declarations, except that parameter
// lists carry no trailing comma. Keep both lists in sync when a method is
// added, removed or changed.
repository_impl!(UpstreamOAuthProviderRepository:
async fn lookup(&mut self, id: Ulid) -> Result<Option<UpstreamOAuthProvider>, Self::Error>;
async fn add(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &dyn Clock,
params: UpstreamOAuthProviderParams
) -> Result<UpstreamOAuthProvider, Self::Error>;
async fn upsert(
&mut self,
clock: &dyn Clock,
id: Ulid,
params: UpstreamOAuthProviderParams
) -> Result<UpstreamOAuthProvider, Self::Error>;
async fn delete(&mut self, provider: UpstreamOAuthProvider) -> Result<(), Self::Error>;
async fn delete_by_id(&mut self, id: Ulid) -> Result<(), Self::Error>;
async fn disable(
&mut self,
clock: &dyn Clock,
provider: UpstreamOAuthProvider
) -> Result<UpstreamOAuthProvider, Self::Error>;
async fn list(
&mut self,
filter: UpstreamOAuthProviderFilter<'_>,
pagination: Pagination
) -> Result<Page<UpstreamOAuthProvider>, Self::Error>;
async fn count(
&mut self,
filter: UpstreamOAuthProviderFilter<'_>
) -> Result<usize, Self::Error>;
async fn all_enabled(&mut self) -> Result<Vec<UpstreamOAuthProvider>, Self::Error>;
);