Skip to main content

mas_handlers/upstream_oauth2/
cache.rs

1// Copyright 2024, 2025 New Vector Ltd.
2// Copyright 2023, 2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
5// Please see LICENSE files in the repository root for full details.
6
7use std::{collections::HashMap, sync::Arc};
8
9use mas_context::LogContext;
10use mas_data_model::{
11    UpstreamOAuthProvider, UpstreamOAuthProviderDiscoveryMode, UpstreamOAuthProviderPkceMode,
12};
13use mas_iana::oauth::PkceCodeChallengeMethod;
14use mas_oidc_client::error::DiscoveryError;
15use mas_storage::{RepositoryAccess, upstream_oauth2::UpstreamOAuthProviderRepository};
16use oauth2_types::oidc::VerifiedProviderMetadata;
17use tokio::sync::RwLock;
18use url::Url;
19
20/// A high-level layer over metadata cache and provider configuration, which
21/// resolves endpoint overrides and discovery modes.
22pub struct LazyProviderInfos<'a> {
23    cache: &'a MetadataCache,
24    provider: &'a UpstreamOAuthProvider,
25    client: &'a reqwest::Client,
26    loaded_metadata: Option<Arc<VerifiedProviderMetadata>>,
27}
28
29impl<'a> LazyProviderInfos<'a> {
30    pub fn new(
31        cache: &'a MetadataCache,
32        provider: &'a UpstreamOAuthProvider,
33        client: &'a reqwest::Client,
34    ) -> Self {
35        Self {
36            cache,
37            provider,
38            client,
39            loaded_metadata: None,
40        }
41    }
42
43    /// Trigger the discovery process and return the metadata if discovery is
44    /// enabled.
45    pub async fn maybe_discover(
46        &mut self,
47    ) -> Result<Option<&VerifiedProviderMetadata>, DiscoveryError> {
48        match self.load().await {
49            Ok(metadata) => Ok(Some(metadata)),
50            Err(DiscoveryError::Disabled) => Ok(None),
51            Err(e) => Err(e),
52        }
53    }
54
55    async fn load(&mut self) -> Result<&VerifiedProviderMetadata, DiscoveryError> {
56        if self.loaded_metadata.is_none() {
57            let verify = match self.provider.discovery_mode {
58                UpstreamOAuthProviderDiscoveryMode::Oidc => true,
59                UpstreamOAuthProviderDiscoveryMode::Insecure => false,
60                UpstreamOAuthProviderDiscoveryMode::Disabled => {
61                    return Err(DiscoveryError::Disabled);
62                }
63            };
64
65            let Some(issuer) = &self.provider.issuer else {
66                return Err(DiscoveryError::MissingIssuer);
67            };
68
69            let metadata = self.cache.get(self.client, issuer, verify).await?;
70
71            self.loaded_metadata = Some(metadata);
72        }
73
74        Ok(self.loaded_metadata.as_ref().unwrap())
75    }
76
77    /// Get the JWKS URI for the provider.
78    ///
79    /// Uses [`UpstreamOAuthProvider.jwks_uri_override`] if set, otherwise uses
80    /// the one from discovery.
81    pub async fn jwks_uri(&mut self) -> Result<&Url, DiscoveryError> {
82        if let Some(jwks_uri) = &self.provider.jwks_uri_override {
83            return Ok(jwks_uri);
84        }
85
86        Ok(self.load().await?.jwks_uri())
87    }
88
89    /// Get the authorization endpoint for the provider.
90    ///
91    /// Uses [`UpstreamOAuthProvider.authorization_endpoint_override`] if set,
92    /// otherwise uses the one from discovery.
93    pub async fn authorization_endpoint(&mut self) -> Result<&Url, DiscoveryError> {
94        if let Some(authorization_endpoint) = &self.provider.authorization_endpoint_override {
95            return Ok(authorization_endpoint);
96        }
97
98        Ok(self.load().await?.authorization_endpoint())
99    }
100
101    /// Get the token endpoint for the provider.
102    ///
103    /// Uses [`UpstreamOAuthProvider.token_endpoint_override`] if set, otherwise
104    /// uses the one from discovery.
105    pub async fn token_endpoint(&mut self) -> Result<&Url, DiscoveryError> {
106        if let Some(token_endpoint) = &self.provider.token_endpoint_override {
107            return Ok(token_endpoint);
108        }
109
110        Ok(self.load().await?.token_endpoint())
111    }
112
113    /// Get the userinfo endpoint for the provider.
114    ///
115    /// Uses [`UpstreamOAuthProvider.userinfo_endpoint_override`] if set,
116    /// otherwise uses the one from discovery.
117    pub async fn userinfo_endpoint(&mut self) -> Result<&Url, DiscoveryError> {
118        if let Some(userinfo_endpoint) = &self.provider.userinfo_endpoint_override {
119            return Ok(userinfo_endpoint);
120        }
121
122        Ok(self.load().await?.userinfo_endpoint())
123    }
124
125    /// Get the PKCE methods supported by the provider.
126    ///
127    /// If the mode is set to auto, it will use the ones from discovery,
128    /// defaulting to none if discovery is disabled.
129    pub async fn pkce_methods(
130        &mut self,
131    ) -> Result<Option<Vec<PkceCodeChallengeMethod>>, DiscoveryError> {
132        let methods = match self.provider.pkce_mode {
133            UpstreamOAuthProviderPkceMode::Auto => self
134                .maybe_discover()
135                .await?
136                .and_then(|metadata| metadata.code_challenge_methods_supported.clone()),
137            UpstreamOAuthProviderPkceMode::S256 => Some(vec![PkceCodeChallengeMethod::S256]),
138            UpstreamOAuthProviderPkceMode::Disabled => None,
139        };
140
141        Ok(methods)
142    }
143}
144
145/// A simple OIDC metadata cache
146///
147/// It never evicts entries, does not cache failures and has no locking.
148/// It can also be refreshed in the background, and warmed up on startup.
149/// It is good enough for our use case.
150#[derive(Debug, Clone, Default)]
151pub struct MetadataCache {
152    cache: Arc<RwLock<HashMap<String, Arc<VerifiedProviderMetadata>>>>,
153    insecure_cache: Arc<RwLock<HashMap<String, Arc<VerifiedProviderMetadata>>>>,
154}
155
156impl MetadataCache {
157    #[must_use]
158    pub fn new() -> Self {
159        Self::default()
160    }
161
162    /// Warm up the cache by fetching all the known providers from the database
163    /// and inserting them into the cache.
164    ///
165    /// This spawns a background task that will refresh the cache at the given
166    /// interval.
167    ///
168    /// # Errors
169    ///
170    /// Returns an error if the warm up task could not be started.
171    #[tracing::instrument(name = "metadata_cache.warm_up_and_run", skip_all)]
172    pub async fn warm_up_and_run<R: RepositoryAccess>(
173        &self,
174        client: &reqwest::Client,
175        interval: std::time::Duration,
176        repository: &mut R,
177    ) -> Result<tokio::task::JoinHandle<()>, R::Error> {
178        let providers = repository.upstream_oauth_provider().all_enabled().await?;
179
180        for provider in providers {
181            let verify = match provider.discovery_mode {
182                UpstreamOAuthProviderDiscoveryMode::Oidc => true,
183                UpstreamOAuthProviderDiscoveryMode::Insecure => false,
184                UpstreamOAuthProviderDiscoveryMode::Disabled => continue,
185            };
186
187            let Some(issuer) = &provider.issuer else {
188                tracing::error!(%provider.id, "Provider doesn't have an issuer set, but discovery is enabled!");
189                continue;
190            };
191
192            if let Err(e) = self.fetch(client, issuer, verify).await {
193                tracing::error!(%issuer, error = &e as &dyn std::error::Error, "Failed to fetch provider metadata");
194            }
195        }
196
197        // Spawn a background task to refresh the cache regularly
198        let cache = self.clone();
199        let client = client.clone();
200        Ok(tokio::spawn(async move {
201            loop {
202                // Re-fetch the known metadata at the given interval
203                tokio::time::sleep(interval).await;
204                LogContext::new("metadata-cache-refresh")
205                    .run(|| cache.refresh_all(&client))
206                    .await;
207            }
208        }))
209    }
210
211    #[tracing::instrument(name = "metadata_cache.fetch", fields(%issuer), skip_all)]
212    async fn fetch(
213        &self,
214        client: &reqwest::Client,
215        issuer: &str,
216        verify: bool,
217    ) -> Result<Arc<VerifiedProviderMetadata>, DiscoveryError> {
218        if verify {
219            let metadata = mas_oidc_client::requests::discovery::discover(client, issuer).await?;
220            let metadata = Arc::new(metadata);
221
222            self.cache
223                .write()
224                .await
225                .insert(issuer.to_owned(), metadata.clone());
226
227            Ok(metadata)
228        } else {
229            let metadata =
230                mas_oidc_client::requests::discovery::insecure_discover(client, issuer).await?;
231            let metadata = Arc::new(metadata);
232
233            self.insecure_cache
234                .write()
235                .await
236                .insert(issuer.to_owned(), metadata.clone());
237
238            Ok(metadata)
239        }
240    }
241
242    /// Get the metadata for the given issuer.
243    ///
244    /// # Errors
245    ///
246    /// Returns an error if the metadata could not be retrieved.
247    #[tracing::instrument(name = "metadata_cache.get", fields(%issuer), skip_all)]
248    pub async fn get(
249        &self,
250        client: &reqwest::Client,
251        issuer: &str,
252        verify: bool,
253    ) -> Result<Arc<VerifiedProviderMetadata>, DiscoveryError> {
254        let cache = if verify {
255            self.cache.read().await
256        } else {
257            self.insecure_cache.read().await
258        };
259
260        if let Some(metadata) = cache.get(issuer) {
261            return Ok(Arc::clone(metadata));
262        }
263        // Drop the cache guard so that we don't deadlock when we try to fetch
264        drop(cache);
265
266        let metadata = self.fetch(client, issuer, verify).await?;
267        Ok(metadata)
268    }
269
270    #[tracing::instrument(name = "metadata_cache.refresh_all", skip_all)]
271    async fn refresh_all(&self, client: &reqwest::Client) {
272        // Grab all the keys first to avoid locking the cache for too long
273        let keys: Vec<String> = {
274            let cache = self.cache.read().await;
275            cache.keys().cloned().collect()
276        };
277
278        for issuer in keys {
279            if let Err(e) = self.fetch(client, &issuer, true).await {
280                tracing::error!(issuer = %issuer, error = &e as &dyn std::error::Error, "Failed to refresh provider metadata");
281            }
282        }
283
284        // Do the same for the insecure cache
285        let keys: Vec<String> = {
286            let cache = self.insecure_cache.read().await;
287            cache.keys().cloned().collect()
288        };
289
290        for issuer in keys {
291            if let Err(e) = self.fetch(client, &issuer, false).await {
292                tracing::error!(issuer = %issuer, error = &e as &dyn std::error::Error, "Failed to refresh provider metadata");
293            }
294        }
295    }
296}
297
#[cfg(test)]
mod tests {
    // XXX: sadly, we can't test HTTPS requests with wiremock, so we can only test
    // 'insecure' discovery

    use mas_data_model::{
        Clock, UpstreamOAuthProviderClaimsImports, UpstreamOAuthProviderOnBackchannelLogout,
        UpstreamOAuthProviderTokenAuthMethod, clock::MockClock,
    };
    use mas_iana::jose::JsonWebSignatureAlg;
    use oauth2_types::scope::{OPENID, Scope};
    use ulid::Ulid;
    use wiremock::{
        Mock, MockServer, ResponseTemplate,
        matchers::{method, path},
    };

    use super::*;
    use crate::test_utils::setup;

    #[tokio::test]
    async fn test_metadata_cache() {
        setup();
        let mock_server = MockServer::start().await;
        let http_client = mas_http::reqwest_client();

        let cache = MetadataCache::new();

        // A nonexistent issuer should fail: nothing is mounted on the mock
        // server yet
        cache
            .get(&http_client, &mock_server.uri(), false)
            .await
            .unwrap_err();

        let expected_calls = 3;
        let mut calls = 0;
        let _mock_guard = Mock::given(method("GET"))
            .and(path("/.well-known/openid-configuration"))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "issuer": mock_server.uri(),
                "authorization_endpoint": "https://example.com/authorize",
                "token_endpoint": "https://example.com/token",
                "jwks_uri": "https://example.com/jwks",
                "userinfo_endpoint": "https://example.com/userinfo",
                "scopes_supported": ["openid"],
                "response_types_supported": ["code"],
                "response_modes_supported": ["query", "fragment"],
                "grant_types_supported": ["authorization_code"],
                "subject_types_supported": ["public"],
                "id_token_signing_alg_values_supported": ["RS256"],
            })))
            .expect(expected_calls)
            .mount(&mock_server)
            .await;

        // A valid issuer should succeed
        cache
            .get(&http_client, &mock_server.uri(), false)
            .await
            .unwrap();
        calls += 1;

        // Calling again should not trigger a new fetch
        cache
            .get(&http_client, &mock_server.uri(), false)
            .await
            .unwrap();
        calls += 0;

        // A secure discovery should call but fail because the issuer is insecure
        cache
            .get(&http_client, &mock_server.uri(), true)
            .await
            .unwrap_err();
        calls += 1;

        // Calling refresh should refresh all the known issuers. Only the
        // insecure entry is cached (the secure fetch failed), so this is a
        // single call
        cache.refresh_all(&http_client).await;
        calls += 1;

        assert_eq!(calls, expected_calls);
    }

    #[tokio::test]
    async fn test_lazy_provider_infos() {
        setup();

        let mock_server = MockServer::start().await;
        let http_client = mas_http::reqwest_client();

        let expected_calls = 2;
        let mut calls = 0;
        let _mock_guard = Mock::given(method("GET"))
            .and(path("/.well-known/openid-configuration"))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "issuer": mock_server.uri(),
                "authorization_endpoint": "https://example.com/authorize",
                "token_endpoint": "https://example.com/token",
                "jwks_uri": "https://example.com/jwks",
                "userinfo_endpoint": "https://example.com/userinfo",
                "scopes_supported": ["openid"],
                "response_types_supported": ["code"],
                "response_modes_supported": ["query", "fragment"],
                "grant_types_supported": ["authorization_code"],
                "subject_types_supported": ["public"],
                "id_token_signing_alg_values_supported": ["RS256"],
            })))
            .expect(expected_calls)
            .mount(&mock_server)
            .await;

        let clock = MockClock::default();
        let provider = UpstreamOAuthProvider {
            id: Ulid::nil(),
            issuer: Some(mock_server.uri()),
            human_name: Some("Example Ltd.".to_owned()),
            brand_name: None,
            discovery_mode: UpstreamOAuthProviderDiscoveryMode::Insecure,
            pkce_mode: UpstreamOAuthProviderPkceMode::Auto,
            fetch_userinfo: false,
            userinfo_signed_response_alg: None,
            jwks_uri_override: None,
            authorization_endpoint_override: None,
            scope: Scope::from_iter([OPENID]),
            userinfo_endpoint_override: None,
            token_endpoint_override: None,
            client_id: "client_id".to_owned(),
            encrypted_client_secret: None,
            token_endpoint_signing_alg: None,
            token_endpoint_auth_method: UpstreamOAuthProviderTokenAuthMethod::None,
            id_token_signed_response_alg: JsonWebSignatureAlg::Rs256,
            response_mode: None,
            created_at: clock.now(),
            disabled_at: None,
            claims_imports: UpstreamOAuthProviderClaimsImports::default(),
            additional_authorization_parameters: Vec::new(),
            forward_login_hint: false,
            on_backchannel_logout: UpstreamOAuthProviderOnBackchannelLogout::DoNothing,
        };

        // Without any override, it should just use discovery
        {
            let cache = MetadataCache::new();
            let mut lazy_metadata = LazyProviderInfos::new(&cache, &provider, &http_client);
            lazy_metadata.maybe_discover().await.unwrap();
            assert_eq!(
                lazy_metadata
                    .authorization_endpoint()
                    .await
                    .unwrap()
                    .as_str(),
                "https://example.com/authorize"
            );
            calls += 1;
        }

        // Test overriding endpoints
        {
            let provider = UpstreamOAuthProvider {
                jwks_uri_override: Some("https://example.com/jwks_override".parse().unwrap()),
                authorization_endpoint_override: Some(
                    "https://example.com/authorize_override".parse().unwrap(),
                ),
                token_endpoint_override: Some(
                    "https://example.com/token_override".parse().unwrap(),
                ),
                ..provider.clone()
            };
            let cache = MetadataCache::new();
            let mut lazy_metadata = LazyProviderInfos::new(&cache, &provider, &http_client);
            assert_eq!(
                lazy_metadata.jwks_uri().await.unwrap().as_str(),
                "https://example.com/jwks_override"
            );
            assert_eq!(
                lazy_metadata
                    .authorization_endpoint()
                    .await
                    .unwrap()
                    .as_str(),
                "https://example.com/authorize_override"
            );
            assert_eq!(
                lazy_metadata.token_endpoint().await.unwrap().as_str(),
                "https://example.com/token_override"
            );
            // This shouldn't trigger a new fetch as the endpoints are overridden
            calls += 0;
        }

        // Loading an insecure provider with secure discovery should fail
        {
            let provider = UpstreamOAuthProvider {
                discovery_mode: UpstreamOAuthProviderDiscoveryMode::Oidc,
                ..provider.clone()
            };
            let cache = MetadataCache::new();
            let mut lazy_metadata = LazyProviderInfos::new(&cache, &provider, &http_client);
            lazy_metadata.authorization_endpoint().await.unwrap_err();
            // This triggered a fetch, even though it failed
            calls += 1;
        }

        // Getting endpoints when discovery is disabled only works for overridden ones
        {
            let provider = UpstreamOAuthProvider {
                discovery_mode: UpstreamOAuthProviderDiscoveryMode::Disabled,
                authorization_endpoint_override: Some(
                    Url::parse("https://example.com/authorize_override").unwrap(),
                ),
                token_endpoint_override: None,
                ..provider.clone()
            };
            let cache = MetadataCache::new();
            let mut lazy_metadata = LazyProviderInfos::new(&cache, &provider, &http_client);
            // This should not fail, but also does nothing
            assert!(lazy_metadata.maybe_discover().await.unwrap().is_none());
            assert_eq!(
                lazy_metadata
                    .authorization_endpoint()
                    .await
                    .unwrap()
                    .as_str(),
                "https://example.com/authorize_override"
            );
            assert!(matches!(
                lazy_metadata.token_endpoint().await,
                Err(DiscoveryError::Disabled),
            ));
            // This did not trigger a fetch
            calls += 0;
        }

        assert_eq!(calls, expected_calls);
    }
}