1use std::{collections::HashMap, sync::Arc};
8
9use mas_context::LogContext;
10use mas_data_model::{
11 UpstreamOAuthProvider, UpstreamOAuthProviderDiscoveryMode, UpstreamOAuthProviderPkceMode,
12};
13use mas_iana::oauth::PkceCodeChallengeMethod;
14use mas_oidc_client::error::DiscoveryError;
15use mas_storage::{RepositoryAccess, upstream_oauth2::UpstreamOAuthProviderRepository};
16use oauth2_types::oidc::VerifiedProviderMetadata;
17use tokio::sync::RwLock;
18use url::Url;
19
20pub struct LazyProviderInfos<'a> {
23 cache: &'a MetadataCache,
24 provider: &'a UpstreamOAuthProvider,
25 client: &'a reqwest::Client,
26 loaded_metadata: Option<Arc<VerifiedProviderMetadata>>,
27}
28
29impl<'a> LazyProviderInfos<'a> {
30 pub fn new(
31 cache: &'a MetadataCache,
32 provider: &'a UpstreamOAuthProvider,
33 client: &'a reqwest::Client,
34 ) -> Self {
35 Self {
36 cache,
37 provider,
38 client,
39 loaded_metadata: None,
40 }
41 }
42
43 pub async fn maybe_discover(
46 &mut self,
47 ) -> Result<Option<&VerifiedProviderMetadata>, DiscoveryError> {
48 match self.load().await {
49 Ok(metadata) => Ok(Some(metadata)),
50 Err(DiscoveryError::Disabled) => Ok(None),
51 Err(e) => Err(e),
52 }
53 }
54
55 async fn load(&mut self) -> Result<&VerifiedProviderMetadata, DiscoveryError> {
56 if self.loaded_metadata.is_none() {
57 let verify = match self.provider.discovery_mode {
58 UpstreamOAuthProviderDiscoveryMode::Oidc => true,
59 UpstreamOAuthProviderDiscoveryMode::Insecure => false,
60 UpstreamOAuthProviderDiscoveryMode::Disabled => {
61 return Err(DiscoveryError::Disabled);
62 }
63 };
64
65 let Some(issuer) = &self.provider.issuer else {
66 return Err(DiscoveryError::MissingIssuer);
67 };
68
69 let metadata = self.cache.get(self.client, issuer, verify).await?;
70
71 self.loaded_metadata = Some(metadata);
72 }
73
74 Ok(self.loaded_metadata.as_ref().unwrap())
75 }
76
77 pub async fn jwks_uri(&mut self) -> Result<&Url, DiscoveryError> {
82 if let Some(jwks_uri) = &self.provider.jwks_uri_override {
83 return Ok(jwks_uri);
84 }
85
86 Ok(self.load().await?.jwks_uri())
87 }
88
89 pub async fn authorization_endpoint(&mut self) -> Result<&Url, DiscoveryError> {
94 if let Some(authorization_endpoint) = &self.provider.authorization_endpoint_override {
95 return Ok(authorization_endpoint);
96 }
97
98 Ok(self.load().await?.authorization_endpoint())
99 }
100
101 pub async fn token_endpoint(&mut self) -> Result<&Url, DiscoveryError> {
106 if let Some(token_endpoint) = &self.provider.token_endpoint_override {
107 return Ok(token_endpoint);
108 }
109
110 Ok(self.load().await?.token_endpoint())
111 }
112
113 pub async fn userinfo_endpoint(&mut self) -> Result<&Url, DiscoveryError> {
118 if let Some(userinfo_endpoint) = &self.provider.userinfo_endpoint_override {
119 return Ok(userinfo_endpoint);
120 }
121
122 Ok(self.load().await?.userinfo_endpoint())
123 }
124
125 pub async fn pkce_methods(
130 &mut self,
131 ) -> Result<Option<Vec<PkceCodeChallengeMethod>>, DiscoveryError> {
132 let methods = match self.provider.pkce_mode {
133 UpstreamOAuthProviderPkceMode::Auto => self
134 .maybe_discover()
135 .await?
136 .and_then(|metadata| metadata.code_challenge_methods_supported.clone()),
137 UpstreamOAuthProviderPkceMode::S256 => Some(vec![PkceCodeChallengeMethod::S256]),
138 UpstreamOAuthProviderPkceMode::Disabled => None,
139 };
140
141 Ok(methods)
142 }
143}
144
145#[derive(Debug, Clone, Default)]
151pub struct MetadataCache {
152 cache: Arc<RwLock<HashMap<String, Arc<VerifiedProviderMetadata>>>>,
153 insecure_cache: Arc<RwLock<HashMap<String, Arc<VerifiedProviderMetadata>>>>,
154}
155
156impl MetadataCache {
157 #[must_use]
158 pub fn new() -> Self {
159 Self::default()
160 }
161
162 #[tracing::instrument(name = "metadata_cache.warm_up_and_run", skip_all)]
172 pub async fn warm_up_and_run<R: RepositoryAccess>(
173 &self,
174 client: &reqwest::Client,
175 interval: std::time::Duration,
176 repository: &mut R,
177 ) -> Result<tokio::task::JoinHandle<()>, R::Error> {
178 let providers = repository.upstream_oauth_provider().all_enabled().await?;
179
180 for provider in providers {
181 let verify = match provider.discovery_mode {
182 UpstreamOAuthProviderDiscoveryMode::Oidc => true,
183 UpstreamOAuthProviderDiscoveryMode::Insecure => false,
184 UpstreamOAuthProviderDiscoveryMode::Disabled => continue,
185 };
186
187 let Some(issuer) = &provider.issuer else {
188 tracing::error!(%provider.id, "Provider doesn't have an issuer set, but discovery is enabled!");
189 continue;
190 };
191
192 if let Err(e) = self.fetch(client, issuer, verify).await {
193 tracing::error!(%issuer, error = &e as &dyn std::error::Error, "Failed to fetch provider metadata");
194 }
195 }
196
197 let cache = self.clone();
199 let client = client.clone();
200 Ok(tokio::spawn(async move {
201 loop {
202 tokio::time::sleep(interval).await;
204 LogContext::new("metadata-cache-refresh")
205 .run(|| cache.refresh_all(&client))
206 .await;
207 }
208 }))
209 }
210
211 #[tracing::instrument(name = "metadata_cache.fetch", fields(%issuer), skip_all)]
212 async fn fetch(
213 &self,
214 client: &reqwest::Client,
215 issuer: &str,
216 verify: bool,
217 ) -> Result<Arc<VerifiedProviderMetadata>, DiscoveryError> {
218 if verify {
219 let metadata = mas_oidc_client::requests::discovery::discover(client, issuer).await?;
220 let metadata = Arc::new(metadata);
221
222 self.cache
223 .write()
224 .await
225 .insert(issuer.to_owned(), metadata.clone());
226
227 Ok(metadata)
228 } else {
229 let metadata =
230 mas_oidc_client::requests::discovery::insecure_discover(client, issuer).await?;
231 let metadata = Arc::new(metadata);
232
233 self.insecure_cache
234 .write()
235 .await
236 .insert(issuer.to_owned(), metadata.clone());
237
238 Ok(metadata)
239 }
240 }
241
242 #[tracing::instrument(name = "metadata_cache.get", fields(%issuer), skip_all)]
248 pub async fn get(
249 &self,
250 client: &reqwest::Client,
251 issuer: &str,
252 verify: bool,
253 ) -> Result<Arc<VerifiedProviderMetadata>, DiscoveryError> {
254 let cache = if verify {
255 self.cache.read().await
256 } else {
257 self.insecure_cache.read().await
258 };
259
260 if let Some(metadata) = cache.get(issuer) {
261 return Ok(Arc::clone(metadata));
262 }
263 drop(cache);
265
266 let metadata = self.fetch(client, issuer, verify).await?;
267 Ok(metadata)
268 }
269
270 #[tracing::instrument(name = "metadata_cache.refresh_all", skip_all)]
271 async fn refresh_all(&self, client: &reqwest::Client) {
272 let keys: Vec<String> = {
274 let cache = self.cache.read().await;
275 cache.keys().cloned().collect()
276 };
277
278 for issuer in keys {
279 if let Err(e) = self.fetch(client, &issuer, true).await {
280 tracing::error!(issuer = %issuer, error = &e as &dyn std::error::Error, "Failed to refresh provider metadata");
281 }
282 }
283
284 let keys: Vec<String> = {
286 let cache = self.insecure_cache.read().await;
287 cache.keys().cloned().collect()
288 };
289
290 for issuer in keys {
291 if let Err(e) = self.fetch(client, &issuer, false).await {
292 tracing::error!(issuer = %issuer, error = &e as &dyn std::error::Error, "Failed to refresh provider metadata");
293 }
294 }
295 }
296}
297
#[cfg(test)]
mod tests {
    // Verified discovery against the plain-HTTP mock server is expected to
    // fail (see the `unwrap_err` assertions below), so only insecure
    // discovery is exercised successfully here.

    use mas_data_model::{
        Clock, UpstreamOAuthProviderClaimsImports, UpstreamOAuthProviderOnBackchannelLogout,
        UpstreamOAuthProviderTokenAuthMethod, clock::MockClock,
    };
    use mas_iana::jose::JsonWebSignatureAlg;
    use oauth2_types::scope::{OPENID, Scope};
    use ulid::Ulid;
    use wiremock::{
        Mock, MockServer, ResponseTemplate,
        matchers::{method, path},
    };

    use super::*;
    use crate::test_utils::setup;

    /// Mount a minimal discovery document on `server`, with a wiremock
    /// expectation that it is requested exactly `expected_requests` times.
    async fn mount_discovery_document(server: &MockServer, expected_requests: u64) {
        let document = serde_json::json!({
            "issuer": server.uri(),
            "authorization_endpoint": "https://example.com/authorize",
            "token_endpoint": "https://example.com/token",
            "jwks_uri": "https://example.com/jwks",
            "userinfo_endpoint": "https://example.com/userinfo",
            "scopes_supported": ["openid"],
            "response_types_supported": ["code"],
            "response_modes_supported": ["query", "fragment"],
            "grant_types_supported": ["authorization_code"],
            "subject_types_supported": ["public"],
            "id_token_signing_alg_values_supported": ["RS256"],
        });

        Mock::given(method("GET"))
            .and(path("/.well-known/openid-configuration"))
            .respond_with(ResponseTemplate::new(200).set_body_json(document))
            .expect(expected_requests)
            .mount(server)
            .await;
    }

    #[tokio::test]
    async fn test_metadata_cache() {
        setup();
        let server = MockServer::start().await;
        let client = mas_http::reqwest_client();
        let cache = MetadataCache::new();
        let issuer = server.uri();

        // Nothing is mounted yet: the lookup fails and nothing gets cached
        cache.get(&client, &issuer, false).await.unwrap_err();

        let expected_requests = 3;
        mount_discovery_document(&server, expected_requests).await;
        let mut requests = 0;

        // Insecure lookup, cache miss: one request
        cache.get(&client, &issuer, false).await.unwrap();
        requests += 1;

        // Insecure lookup again, cache hit: no request
        cache.get(&client, &issuer, false).await.unwrap();

        // Verified lookup uses a separate cache and fails: one request
        cache.get(&client, &issuer, true).await.unwrap_err();
        requests += 1;

        // Only the insecure entry is cached, so refreshing makes one request
        cache.refresh_all(&client).await;
        requests += 1;

        assert_eq!(requests, expected_requests);
    }

    #[tokio::test]
    async fn test_lazy_provider_infos() {
        setup();

        let server = MockServer::start().await;
        let client = mas_http::reqwest_client();

        let expected_requests = 2;
        mount_discovery_document(&server, expected_requests).await;
        let mut requests = 0;

        let clock = MockClock::default();
        let base_provider = UpstreamOAuthProvider {
            id: Ulid::nil(),
            issuer: Some(server.uri()),
            human_name: Some("Example Ltd.".to_owned()),
            brand_name: None,
            discovery_mode: UpstreamOAuthProviderDiscoveryMode::Insecure,
            pkce_mode: UpstreamOAuthProviderPkceMode::Auto,
            fetch_userinfo: false,
            userinfo_signed_response_alg: None,
            jwks_uri_override: None,
            authorization_endpoint_override: None,
            scope: Scope::from_iter([OPENID]),
            userinfo_endpoint_override: None,
            token_endpoint_override: None,
            client_id: "client_id".to_owned(),
            encrypted_client_secret: None,
            token_endpoint_signing_alg: None,
            token_endpoint_auth_method: UpstreamOAuthProviderTokenAuthMethod::None,
            id_token_signed_response_alg: JsonWebSignatureAlg::Rs256,
            response_mode: None,
            created_at: clock.now(),
            disabled_at: None,
            claims_imports: UpstreamOAuthProviderClaimsImports::default(),
            additional_authorization_parameters: Vec::new(),
            forward_login_hint: false,
            on_backchannel_logout: UpstreamOAuthProviderOnBackchannelLogout::DoNothing,
        };

        // Insecure discovery, no overrides: one request
        {
            let cache = MetadataCache::new();
            let mut infos = LazyProviderInfos::new(&cache, &base_provider, &client);
            infos.maybe_discover().await.unwrap();
            assert_eq!(
                infos.authorization_endpoint().await.unwrap().as_str(),
                "https://example.com/authorize"
            );
            requests += 1;
        }

        // Every queried endpoint is overridden: no request
        {
            let provider = UpstreamOAuthProvider {
                jwks_uri_override: Some("https://example.com/jwks_override".parse().unwrap()),
                authorization_endpoint_override: Some(
                    "https://example.com/authorize_override".parse().unwrap(),
                ),
                token_endpoint_override: Some(
                    "https://example.com/token_override".parse().unwrap(),
                ),
                ..base_provider.clone()
            };
            let cache = MetadataCache::new();
            let mut infos = LazyProviderInfos::new(&cache, &provider, &client);
            assert_eq!(
                infos.jwks_uri().await.unwrap().as_str(),
                "https://example.com/jwks_override"
            );
            assert_eq!(
                infos.authorization_endpoint().await.unwrap().as_str(),
                "https://example.com/authorize_override"
            );
            assert_eq!(
                infos.token_endpoint().await.unwrap().as_str(),
                "https://example.com/token_override"
            );
        }

        // Verified discovery against the mock server fails: one request
        {
            let provider = UpstreamOAuthProvider {
                discovery_mode: UpstreamOAuthProviderDiscoveryMode::Oidc,
                ..base_provider.clone()
            };
            let cache = MetadataCache::new();
            let mut infos = LazyProviderInfos::new(&cache, &provider, &client);
            infos.authorization_endpoint().await.unwrap_err();
            requests += 1;
        }

        // Discovery disabled: overrides still resolve, anything else errors
        // with `Disabled`. No request.
        {
            let provider = UpstreamOAuthProvider {
                discovery_mode: UpstreamOAuthProviderDiscoveryMode::Disabled,
                authorization_endpoint_override: Some(
                    Url::parse("https://example.com/authorize_override").unwrap(),
                ),
                token_endpoint_override: None,
                ..base_provider.clone()
            };
            let cache = MetadataCache::new();
            let mut infos = LazyProviderInfos::new(&cache, &provider, &client);
            assert!(infos.maybe_discover().await.unwrap().is_none());
            assert_eq!(
                infos.authorization_endpoint().await.unwrap().as_str(),
                "https://example.com/authorize_override"
            );
            assert!(matches!(
                infos.token_endpoint().await,
                Err(DiscoveryError::Disabled),
            ));
        }

        assert_eq!(requests, expected_requests);
    }
}