mas_axum_utils/
client_authorization.rs

1// Copyright 2024, 2025 New Vector Ltd.
2// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
5// Please see LICENSE files in the repository root for full details.
6
7use std::collections::HashMap;
8
9use axum::{
10    BoxError, Json,
11    extract::{
12        Form, FromRequest,
13        rejection::{FailedToDeserializeForm, FormRejection},
14    },
15    response::IntoResponse,
16};
17use headers::authorization::{Basic, Bearer, Credentials as _};
18use http::{Request, StatusCode};
19use mas_data_model::{Client, JwksOrJwksUri};
20use mas_http::RequestBuilderExt;
21use mas_iana::oauth::OAuthClientAuthenticationMethod;
22use mas_jose::{jwk::PublicJsonWebKeySet, jwt::Jwt};
23use mas_keystore::Encrypter;
24use mas_storage::{RepositoryAccess, oauth2::OAuth2ClientRepository};
25use oauth2_types::errors::{ClientError, ClientErrorCode};
26use serde::{Deserialize, de::DeserializeOwned};
27use serde_json::Value;
28use thiserror::Error;
29
30use crate::record_error;
31
32static JWT_BEARER_CLIENT_ASSERTION: &str = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer";
33
34#[derive(Deserialize)]
35struct AuthorizedForm<F = ()> {
36    client_id: Option<String>,
37    client_secret: Option<String>,
38    client_assertion_type: Option<String>,
39    client_assertion: Option<String>,
40
41    #[serde(flatten)]
42    inner: F,
43}
44
45#[derive(Debug, PartialEq, Eq)]
46pub enum Credentials {
47    None {
48        client_id: String,
49    },
50    ClientSecretBasic {
51        client_id: String,
52        client_secret: String,
53    },
54    ClientSecretPost {
55        client_id: String,
56        client_secret: String,
57    },
58    ClientAssertionJwtBearer {
59        client_id: String,
60        jwt: Box<Jwt<'static, HashMap<String, serde_json::Value>>>,
61    },
62    BearerToken {
63        token: String,
64    },
65}
66
67impl Credentials {
68    /// Get the `client_id` of the credentials
69    #[must_use]
70    pub fn client_id(&self) -> Option<&str> {
71        match self {
72            Credentials::None { client_id }
73            | Credentials::ClientSecretBasic { client_id, .. }
74            | Credentials::ClientSecretPost { client_id, .. }
75            | Credentials::ClientAssertionJwtBearer { client_id, .. } => Some(client_id),
76            Credentials::BearerToken { .. } => None,
77        }
78    }
79
80    /// Get the bearer token from the credentials.
81    #[must_use]
82    pub fn bearer_token(&self) -> Option<&str> {
83        match self {
84            Credentials::BearerToken { token } => Some(token),
85            _ => None,
86        }
87    }
88
89    /// Fetch the client from the database
90    ///
91    /// # Errors
92    ///
93    /// Returns an error if the client could not be found or if the underlying
94    /// repository errored.
95    pub async fn fetch<E>(
96        &self,
97        repo: &mut impl RepositoryAccess<Error = E>,
98    ) -> Result<Option<Client>, E> {
99        let client_id = match self {
100            Credentials::None { client_id }
101            | Credentials::ClientSecretBasic { client_id, .. }
102            | Credentials::ClientSecretPost { client_id, .. }
103            | Credentials::ClientAssertionJwtBearer { client_id, .. } => client_id,
104            Credentials::BearerToken { .. } => return Ok(None),
105        };
106
107        repo.oauth2_client().find_by_client_id(client_id).await
108    }
109
110    /// Verify credentials presented by the client for authentication
111    ///
112    /// # Errors
113    ///
114    /// Returns an error if the credentials are invalid.
115    #[tracing::instrument(skip_all)]
116    pub async fn verify(
117        &self,
118        http_client: &reqwest::Client,
119        encrypter: &Encrypter,
120        method: &OAuthClientAuthenticationMethod,
121        client: &Client,
122    ) -> Result<(), CredentialsVerificationError> {
123        match (self, method) {
124            (Credentials::None { .. }, OAuthClientAuthenticationMethod::None) => {}
125
126            (
127                Credentials::ClientSecretPost { client_secret, .. },
128                OAuthClientAuthenticationMethod::ClientSecretPost,
129            )
130            | (
131                Credentials::ClientSecretBasic { client_secret, .. },
132                OAuthClientAuthenticationMethod::ClientSecretBasic,
133            ) => {
134                // Decrypt the client_secret
135                let encrypted_client_secret = client
136                    .encrypted_client_secret
137                    .as_ref()
138                    .ok_or(CredentialsVerificationError::InvalidClientConfig)?;
139
140                let decrypted_client_secret = encrypter
141                    .decrypt_string(encrypted_client_secret)
142                    .map_err(|_e| CredentialsVerificationError::DecryptionError)?;
143
144                // Check if the client_secret matches
145                if client_secret.as_bytes() != decrypted_client_secret {
146                    return Err(CredentialsVerificationError::ClientSecretMismatch);
147                }
148            }
149
150            (
151                Credentials::ClientAssertionJwtBearer { jwt, .. },
152                OAuthClientAuthenticationMethod::PrivateKeyJwt,
153            ) => {
154                // Get the client JWKS
155                let jwks = client
156                    .jwks
157                    .as_ref()
158                    .ok_or(CredentialsVerificationError::InvalidClientConfig)?;
159
160                let jwks = fetch_jwks(http_client, jwks)
161                    .await
162                    .map_err(CredentialsVerificationError::JwksFetchFailed)?;
163
164                jwt.verify_with_jwks(&jwks)
165                    .map_err(|_| CredentialsVerificationError::InvalidAssertionSignature)?;
166            }
167
168            (
169                Credentials::ClientAssertionJwtBearer { jwt, .. },
170                OAuthClientAuthenticationMethod::ClientSecretJwt,
171            ) => {
172                // Decrypt the client_secret
173                let encrypted_client_secret = client
174                    .encrypted_client_secret
175                    .as_ref()
176                    .ok_or(CredentialsVerificationError::InvalidClientConfig)?;
177
178                let decrypted_client_secret = encrypter
179                    .decrypt_string(encrypted_client_secret)
180                    .map_err(|_e| CredentialsVerificationError::DecryptionError)?;
181
182                jwt.verify_with_shared_secret(decrypted_client_secret)
183                    .map_err(|_| CredentialsVerificationError::InvalidAssertionSignature)?;
184            }
185
186            (_, _) => {
187                return Err(CredentialsVerificationError::AuthenticationMethodMismatch);
188            }
189        }
190        Ok(())
191    }
192}
193
194async fn fetch_jwks(
195    http_client: &reqwest::Client,
196    jwks: &JwksOrJwksUri,
197) -> Result<PublicJsonWebKeySet, BoxError> {
198    let uri = match jwks {
199        JwksOrJwksUri::Jwks(j) => return Ok(j.clone()),
200        JwksOrJwksUri::JwksUri(u) => u,
201    };
202
203    let response = http_client
204        .get(uri.as_str())
205        .send_traced()
206        .await?
207        .error_for_status()?
208        .json()
209        .await?;
210
211    Ok(response)
212}
213
214#[derive(Debug, Error)]
215pub enum CredentialsVerificationError {
216    #[error("failed to decrypt client credentials")]
217    DecryptionError,
218
219    #[error("invalid client configuration")]
220    InvalidClientConfig,
221
222    #[error("client secret did not match")]
223    ClientSecretMismatch,
224
225    #[error("authentication method mismatch")]
226    AuthenticationMethodMismatch,
227
228    #[error("invalid assertion signature")]
229    InvalidAssertionSignature,
230
231    #[error("failed to fetch jwks")]
232    JwksFetchFailed(#[source] Box<dyn std::error::Error + Send + Sync + 'static>),
233}
234
235impl CredentialsVerificationError {
236    /// Returns true if the error is an internal error, not caused by the client
237    #[must_use]
238    pub fn is_internal(&self) -> bool {
239        matches!(
240            self,
241            Self::DecryptionError | Self::InvalidClientConfig | Self::JwksFetchFailed(_)
242        )
243    }
244}
245
246#[derive(Debug, PartialEq, Eq)]
247pub struct ClientAuthorization<F = ()> {
248    pub credentials: Credentials,
249    pub form: Option<F>,
250}
251
252impl<F> ClientAuthorization<F> {
253    /// Get the `client_id` from the credentials.
254    #[must_use]
255    pub fn client_id(&self) -> Option<&str> {
256        self.credentials.client_id()
257    }
258}
259
260#[derive(Debug, Error)]
261pub enum ClientAuthorizationError {
262    #[error("Invalid Authorization header")]
263    InvalidHeader,
264
265    #[error("Could not deserialize request body")]
266    BadForm(#[source] FailedToDeserializeForm),
267
268    #[error("client_id in form ({form:?}) does not match credential ({credential:?})")]
269    ClientIdMismatch { credential: String, form: String },
270
271    #[error("Unsupported client_assertion_type: {client_assertion_type}")]
272    UnsupportedClientAssertion { client_assertion_type: String },
273
274    #[error("No credentials were presented")]
275    MissingCredentials,
276
277    #[error("Invalid request")]
278    InvalidRequest,
279
280    #[error("Invalid client_assertion")]
281    InvalidAssertion,
282
283    #[error(transparent)]
284    Internal(Box<dyn std::error::Error>),
285}
286
287impl IntoResponse for ClientAuthorizationError {
288    fn into_response(self) -> axum::response::Response {
289        let sentry_event_id = record_error!(self, Self::Internal(_));
290        match &self {
291            ClientAuthorizationError::InvalidHeader => (
292                StatusCode::BAD_REQUEST,
293                sentry_event_id,
294                Json(ClientError::new(
295                    ClientErrorCode::InvalidRequest,
296                    "Invalid Authorization header",
297                )),
298            ),
299
300            ClientAuthorizationError::BadForm(err) => (
301                StatusCode::BAD_REQUEST,
302                sentry_event_id,
303                Json(
304                    ClientError::from(ClientErrorCode::InvalidRequest)
305                        .with_description(format!("{err}")),
306                ),
307            ),
308
309            ClientAuthorizationError::ClientIdMismatch { .. } => (
310                StatusCode::BAD_REQUEST,
311                sentry_event_id,
312                Json(
313                    ClientError::from(ClientErrorCode::InvalidGrant)
314                        .with_description(format!("{self}")),
315                ),
316            ),
317
318            ClientAuthorizationError::UnsupportedClientAssertion { .. } => (
319                StatusCode::BAD_REQUEST,
320                sentry_event_id,
321                Json(
322                    ClientError::from(ClientErrorCode::InvalidRequest)
323                        .with_description(format!("{self}")),
324                ),
325            ),
326
327            ClientAuthorizationError::MissingCredentials => (
328                StatusCode::BAD_REQUEST,
329                sentry_event_id,
330                Json(ClientError::new(
331                    ClientErrorCode::InvalidRequest,
332                    "No credentials were presented",
333                )),
334            ),
335
336            ClientAuthorizationError::InvalidRequest => (
337                StatusCode::BAD_REQUEST,
338                sentry_event_id,
339                Json(ClientError::from(ClientErrorCode::InvalidRequest)),
340            ),
341
342            ClientAuthorizationError::InvalidAssertion => (
343                StatusCode::BAD_REQUEST,
344                sentry_event_id,
345                Json(ClientError::new(
346                    ClientErrorCode::InvalidRequest,
347                    "Invalid client_assertion",
348                )),
349            ),
350
351            ClientAuthorizationError::Internal(e) => (
352                StatusCode::INTERNAL_SERVER_ERROR,
353                sentry_event_id,
354                Json(
355                    ClientError::from(ClientErrorCode::ServerError)
356                        .with_description(format!("{e}")),
357                ),
358            ),
359        }
360        .into_response()
361    }
362}
363
364impl<S, F> FromRequest<S> for ClientAuthorization<F>
365where
366    F: DeserializeOwned,
367    S: Send + Sync,
368{
369    type Rejection = ClientAuthorizationError;
370
371    async fn from_request(
372        req: Request<axum::body::Body>,
373        state: &S,
374    ) -> Result<Self, Self::Rejection> {
375        enum Authorization {
376            Basic(String, String),
377            Bearer(String),
378        }
379
380        // Sadly, the typed-header 'Authorization' doesn't let us check for both
381        // Basic and Bearer at the same time, so we need to parse them manually
382        let authorization = if let Some(header) = req.headers().get(http::header::AUTHORIZATION) {
383            let bytes = header.as_bytes();
384            if bytes.len() >= 6 && bytes[..6].eq_ignore_ascii_case(b"Basic ") {
385                let Some(decoded) = Basic::decode(header) else {
386                    return Err(ClientAuthorizationError::InvalidHeader);
387                };
388
389                Some(Authorization::Basic(
390                    decoded.username().to_owned(),
391                    decoded.password().to_owned(),
392                ))
393            } else if bytes.len() >= 7 && bytes[..7].eq_ignore_ascii_case(b"Bearer ") {
394                let Some(decoded) = Bearer::decode(header) else {
395                    return Err(ClientAuthorizationError::InvalidHeader);
396                };
397
398                Some(Authorization::Bearer(decoded.token().to_owned()))
399            } else {
400                return Err(ClientAuthorizationError::InvalidHeader);
401            }
402        } else {
403            None
404        };
405
406        // Take the form value
407        let (
408            client_id_from_form,
409            client_secret_from_form,
410            client_assertion_type,
411            client_assertion,
412            form,
413        ) = match Form::<AuthorizedForm<F>>::from_request(req, state).await {
414            Ok(Form(form)) => (
415                form.client_id,
416                form.client_secret,
417                form.client_assertion_type,
418                form.client_assertion,
419                Some(form.inner),
420            ),
421            // If it is not a form, continue
422            Err(FormRejection::InvalidFormContentType(_err)) => (None, None, None, None, None),
423            // If the form could not be read, return a Bad Request error
424            Err(FormRejection::FailedToDeserializeForm(err)) => {
425                return Err(ClientAuthorizationError::BadForm(err));
426            }
427            // Other errors (body read twice, byte stream broke) return an internal error
428            Err(e) => return Err(ClientAuthorizationError::Internal(Box::new(e))),
429        };
430
431        // And now, figure out the actual auth method
432        let credentials = match (
433            authorization,
434            client_id_from_form,
435            client_secret_from_form,
436            client_assertion_type,
437            client_assertion,
438        ) {
439            (
440                Some(Authorization::Basic(client_id, client_secret)),
441                client_id_from_form,
442                None,
443                None,
444                None,
445            ) => {
446                if let Some(client_id_from_form) = client_id_from_form {
447                    // If the client_id was in the body, verify it matches with the header
448                    if client_id != client_id_from_form {
449                        return Err(ClientAuthorizationError::ClientIdMismatch {
450                            credential: client_id,
451                            form: client_id_from_form,
452                        });
453                    }
454                }
455
456                Credentials::ClientSecretBasic {
457                    client_id,
458                    client_secret,
459                }
460            }
461
462            (None, Some(client_id), Some(client_secret), None, None) => {
463                // Got both client_id and client_secret from the form
464                Credentials::ClientSecretPost {
465                    client_id,
466                    client_secret,
467                }
468            }
469
470            (None, Some(client_id), None, None, None) => {
471                // Only got a client_id in the form
472                Credentials::None { client_id }
473            }
474
475            (
476                None,
477                client_id_from_form,
478                None,
479                Some(client_assertion_type),
480                Some(client_assertion),
481            ) if client_assertion_type == JWT_BEARER_CLIENT_ASSERTION => {
482                // Got a JWT bearer client_assertion
483                let jwt: Jwt<'static, HashMap<String, Value>> = Jwt::try_from(client_assertion)
484                    .map_err(|_| ClientAuthorizationError::InvalidAssertion)?;
485
486                let client_id = if let Some(Value::String(client_id)) = jwt.payload().get("sub") {
487                    client_id.clone()
488                } else {
489                    return Err(ClientAuthorizationError::InvalidAssertion);
490                };
491
492                if let Some(client_id_from_form) = client_id_from_form {
493                    // If the client_id was in the body, verify it matches the one in the JWT
494                    if client_id != client_id_from_form {
495                        return Err(ClientAuthorizationError::ClientIdMismatch {
496                            credential: client_id,
497                            form: client_id_from_form,
498                        });
499                    }
500                }
501
502                Credentials::ClientAssertionJwtBearer {
503                    client_id,
504                    jwt: Box::new(jwt),
505                }
506            }
507
508            (None, None, None, Some(client_assertion_type), Some(_client_assertion)) => {
509                // Got another unsupported client_assertion
510                return Err(ClientAuthorizationError::UnsupportedClientAssertion {
511                    client_assertion_type,
512                });
513            }
514
515            (Some(Authorization::Bearer(token)), None, None, None, None) => {
516                // Got a bearer token
517                Credentials::BearerToken { token }
518            }
519
520            (None, None, None, None, None) => {
521                // Special case when there are no credentials anywhere
522                return Err(ClientAuthorizationError::MissingCredentials);
523            }
524
525            _ => {
526                // Every other combination is an invalid request
527                return Err(ClientAuthorizationError::InvalidRequest);
528            }
529        };
530
531        Ok(ClientAuthorization { credentials, form })
532    }
533}
534
#[cfg(test)]
mod tests {
    use axum::body::Body;
    use http::{Method, Request};

    use super::*;

    /// `Basic` header value encoding `client-id:client-secret`.
    const BASIC_CLIENT_CREDENTIALS: &str = "Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ=";

    /// Build a form-encoded `POST` request, optionally with an
    /// `Authorization` header.
    fn form_request(body: &str, authorization: Option<&str>) -> Request<Body> {
        let mut builder = Request::builder().method(Method::POST).header(
            http::header::CONTENT_TYPE,
            mime::APPLICATION_WWW_FORM_URLENCODED.as_ref(),
        );
        if let Some(value) = authorization {
            builder = builder.header(http::header::AUTHORIZATION, value);
        }
        builder.body(Body::new(body.to_owned())).unwrap()
    }

    /// Run the extractor with a JSON-value form type.
    async fn extract_authorization(
        req: Request<Body>,
    ) -> Result<ClientAuthorization<serde_json::Value>, ClientAuthorizationError> {
        ClientAuthorization::<serde_json::Value>::from_request(req, &()).await
    }

    /// The leftover form every test body carries.
    fn foo_bar_form() -> Option<serde_json::Value> {
        Some(serde_json::json!({"foo": "bar"}))
    }

    #[tokio::test]
    async fn none_test() {
        let authz = extract_authorization(form_request("client_id=client-id&foo=bar", None))
            .await
            .unwrap();

        assert_eq!(
            authz,
            ClientAuthorization {
                credentials: Credentials::None {
                    client_id: "client-id".to_owned(),
                },
                form: foo_bar_form(),
            }
        );
    }

    #[tokio::test]
    async fn client_secret_basic_test() {
        let expected = || ClientAuthorization {
            credentials: Credentials::ClientSecretBasic {
                client_id: "client-id".to_owned(),
                client_secret: "client-secret".to_owned(),
            },
            form: foo_bar_form(),
        };

        // client_id only in the header
        let authz = extract_authorization(form_request("foo=bar", Some(BASIC_CLIENT_CREDENTIALS)))
            .await
            .unwrap();
        assert_eq!(authz, expected());

        // client_id in both header and body
        let authz = extract_authorization(form_request(
            "client_id=client-id&foo=bar",
            Some(BASIC_CLIENT_CREDENTIALS),
        ))
        .await
        .unwrap();
        assert_eq!(authz, expected());

        // client_id in both header and body mismatch
        let outcome = extract_authorization(form_request(
            "client_id=mismatch-id&foo=bar",
            Some(BASIC_CLIENT_CREDENTIALS),
        ))
        .await;
        assert!(matches!(
            outcome,
            Err(ClientAuthorizationError::ClientIdMismatch { .. }),
        ));

        // Invalid header
        let outcome = extract_authorization(form_request("foo=bar", Some("Basic invalid"))).await;
        assert!(matches!(
            outcome,
            Err(ClientAuthorizationError::InvalidHeader),
        ));
    }

    #[tokio::test]
    async fn client_secret_post_test() {
        let authz = extract_authorization(form_request(
            "client_id=client-id&client_secret=client-secret&foo=bar",
            None,
        ))
        .await
        .unwrap();

        assert_eq!(
            authz,
            ClientAuthorization {
                credentials: Credentials::ClientSecretPost {
                    client_id: "client-id".to_owned(),
                    client_secret: "client-secret".to_owned(),
                },
                form: foo_bar_form(),
            }
        );
    }

    #[tokio::test]
    async fn client_assertion_test() {
        // Signed with client_secret = "client-secret"
        let assertion = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJjbGllbnQtaWQiLCJzdWIiOiJjbGllbnQtaWQiLCJhdWQiOiJodHRwczovL2V4YW1wbGUuY29tL29hdXRoMi9pbnRyb3NwZWN0IiwianRpIjoiYWFiYmNjIiwiZXhwIjoxNTE2MjM5MzIyLCJpYXQiOjE1MTYyMzkwMjJ9.XTaACG_Rww0GPecSZvkbem-AczNy9LLNBueCLCiQajU";
        let body = format!(
            "client_assertion_type={JWT_BEARER_CLIENT_ASSERTION}&client_assertion={assertion}&foo=bar",
        );

        let authz = extract_authorization(form_request(&body, None))
            .await
            .unwrap();
        assert_eq!(authz.form, foo_bar_form());

        let Credentials::ClientAssertionJwtBearer { client_id, jwt } = authz.credentials else {
            panic!("expected a JWT client_assertion");
        };

        assert_eq!(client_id, "client-id");
        jwt.verify_with_shared_secret(b"client-secret".to_vec())
            .unwrap();
    }

    #[tokio::test]
    async fn bearer_token_test() {
        let authz = extract_authorization(form_request("foo=bar", Some("Bearer token")))
            .await
            .unwrap();

        assert_eq!(
            authz,
            ClientAuthorization {
                credentials: Credentials::BearerToken {
                    token: "token".to_owned(),
                },
                form: foo_bar_form(),
            }
        );
    }
}