1use std::collections::HashMap;
8
9use axum::{
10 BoxError, Json,
11 extract::{
12 Form, FromRequest,
13 rejection::{FailedToDeserializeForm, FormRejection},
14 },
15 response::IntoResponse,
16};
17use headers::authorization::{Basic, Bearer, Credentials as _};
18use http::{Request, StatusCode};
19use mas_data_model::{Client, JwksOrJwksUri};
20use mas_http::RequestBuilderExt;
21use mas_iana::oauth::OAuthClientAuthenticationMethod;
22use mas_jose::{jwk::PublicJsonWebKeySet, jwt::Jwt};
23use mas_keystore::Encrypter;
24use mas_storage::{RepositoryAccess, oauth2::OAuth2ClientRepository};
25use oauth2_types::errors::{ClientError, ClientErrorCode};
26use serde::{Deserialize, de::DeserializeOwned};
27use serde_json::Value;
28use thiserror::Error;
29
30use crate::record_error;
31
/// Value of the `client_assertion_type` form parameter that identifies a JWT
/// bearer client assertion (RFC 7523).
const JWT_BEARER_CLIENT_ASSERTION: &str = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer";
33
/// Form body accepted by the [`ClientAuthorization`] extractor.
///
/// The client authentication parameters are pulled out into dedicated fields;
/// every other form field is deserialized into `inner` through
/// `#[serde(flatten)]`.
#[derive(Deserialize)]
struct AuthorizedForm<F = ()> {
    /// `client_id` form parameter, if present.
    client_id: Option<String>,
    /// `client_secret` form parameter, if present.
    client_secret: Option<String>,
    /// `client_assertion_type` form parameter, if present.
    client_assertion_type: Option<String>,
    /// `client_assertion` form parameter, if present.
    client_assertion: Option<String>,

    /// The remaining form fields, deserialized as `F`.
    #[serde(flatten)]
    inner: F,
}
44
/// Client credentials found in a request by the [`ClientAuthorization`]
/// extractor.
///
/// Extraction only parses the credentials. Secrets and assertion signatures
/// are checked later by [`Credentials::verify`].
#[derive(Debug, PartialEq, Eq)]
pub enum Credentials {
    /// Only a `client_id` form parameter, with no secret and no assertion.
    None {
        client_id: String,
    },
    /// `client_id` and `client_secret` taken from an `Authorization: Basic` header.
    ClientSecretBasic {
        client_id: String,
        client_secret: String,
    },
    /// `client_id` and `client_secret` taken from the form body.
    ClientSecretPost {
        client_id: String,
        client_secret: String,
    },
    /// A JWT-bearer `client_assertion`. `client_id` is the JWT's `sub` claim.
    ClientAssertionJwtBearer {
        client_id: String,
        jwt: Box<Jwt<'static, HashMap<String, serde_json::Value>>>,
    },
    /// A token taken from an `Authorization: Bearer` header.
    BearerToken {
        token: String,
    },
}
66
67impl Credentials {
68 #[must_use]
70 pub fn client_id(&self) -> Option<&str> {
71 match self {
72 Credentials::None { client_id }
73 | Credentials::ClientSecretBasic { client_id, .. }
74 | Credentials::ClientSecretPost { client_id, .. }
75 | Credentials::ClientAssertionJwtBearer { client_id, .. } => Some(client_id),
76 Credentials::BearerToken { .. } => None,
77 }
78 }
79
80 #[must_use]
82 pub fn bearer_token(&self) -> Option<&str> {
83 match self {
84 Credentials::BearerToken { token } => Some(token),
85 _ => None,
86 }
87 }
88
89 pub async fn fetch<E>(
96 &self,
97 repo: &mut impl RepositoryAccess<Error = E>,
98 ) -> Result<Option<Client>, E> {
99 let client_id = match self {
100 Credentials::None { client_id }
101 | Credentials::ClientSecretBasic { client_id, .. }
102 | Credentials::ClientSecretPost { client_id, .. }
103 | Credentials::ClientAssertionJwtBearer { client_id, .. } => client_id,
104 Credentials::BearerToken { .. } => return Ok(None),
105 };
106
107 repo.oauth2_client().find_by_client_id(client_id).await
108 }
109
110 #[tracing::instrument(skip_all)]
116 pub async fn verify(
117 &self,
118 http_client: &reqwest::Client,
119 encrypter: &Encrypter,
120 method: &OAuthClientAuthenticationMethod,
121 client: &Client,
122 ) -> Result<(), CredentialsVerificationError> {
123 match (self, method) {
124 (Credentials::None { .. }, OAuthClientAuthenticationMethod::None) => {}
125
126 (
127 Credentials::ClientSecretPost { client_secret, .. },
128 OAuthClientAuthenticationMethod::ClientSecretPost,
129 )
130 | (
131 Credentials::ClientSecretBasic { client_secret, .. },
132 OAuthClientAuthenticationMethod::ClientSecretBasic,
133 ) => {
134 let encrypted_client_secret = client
136 .encrypted_client_secret
137 .as_ref()
138 .ok_or(CredentialsVerificationError::InvalidClientConfig)?;
139
140 let decrypted_client_secret = encrypter
141 .decrypt_string(encrypted_client_secret)
142 .map_err(|_e| CredentialsVerificationError::DecryptionError)?;
143
144 if client_secret.as_bytes() != decrypted_client_secret {
146 return Err(CredentialsVerificationError::ClientSecretMismatch);
147 }
148 }
149
150 (
151 Credentials::ClientAssertionJwtBearer { jwt, .. },
152 OAuthClientAuthenticationMethod::PrivateKeyJwt,
153 ) => {
154 let jwks = client
156 .jwks
157 .as_ref()
158 .ok_or(CredentialsVerificationError::InvalidClientConfig)?;
159
160 let jwks = fetch_jwks(http_client, jwks)
161 .await
162 .map_err(CredentialsVerificationError::JwksFetchFailed)?;
163
164 jwt.verify_with_jwks(&jwks)
165 .map_err(|_| CredentialsVerificationError::InvalidAssertionSignature)?;
166 }
167
168 (
169 Credentials::ClientAssertionJwtBearer { jwt, .. },
170 OAuthClientAuthenticationMethod::ClientSecretJwt,
171 ) => {
172 let encrypted_client_secret = client
174 .encrypted_client_secret
175 .as_ref()
176 .ok_or(CredentialsVerificationError::InvalidClientConfig)?;
177
178 let decrypted_client_secret = encrypter
179 .decrypt_string(encrypted_client_secret)
180 .map_err(|_e| CredentialsVerificationError::DecryptionError)?;
181
182 jwt.verify_with_shared_secret(decrypted_client_secret)
183 .map_err(|_| CredentialsVerificationError::InvalidAssertionSignature)?;
184 }
185
186 (_, _) => {
187 return Err(CredentialsVerificationError::AuthenticationMethodMismatch);
188 }
189 }
190 Ok(())
191 }
192}
193
194async fn fetch_jwks(
195 http_client: &reqwest::Client,
196 jwks: &JwksOrJwksUri,
197) -> Result<PublicJsonWebKeySet, BoxError> {
198 let uri = match jwks {
199 JwksOrJwksUri::Jwks(j) => return Ok(j.clone()),
200 JwksOrJwksUri::JwksUri(u) => u,
201 };
202
203 let response = http_client
204 .get(uri.as_str())
205 .send_traced()
206 .await?
207 .error_for_status()?
208 .json()
209 .await?;
210
211 Ok(response)
212}
213
/// Reasons why [`Credentials::verify`] can reject a client's credentials.
#[derive(Debug, Error)]
pub enum CredentialsVerificationError {
    /// The client's stored secret could not be decrypted.
    #[error("failed to decrypt client credentials")]
    DecryptionError,

    /// The client lacks what the authentication method needs: no stored
    /// secret for the secret-based methods, or no JWKS for `PrivateKeyJwt`.
    #[error("invalid client configuration")]
    InvalidClientConfig,

    /// The presented client secret did not match the stored one.
    #[error("client secret did not match")]
    ClientSecretMismatch,

    /// The kind of credentials presented does not fit the authentication
    /// method passed to `verify`.
    #[error("authentication method mismatch")]
    AuthenticationMethodMismatch,

    /// The client assertion JWT failed signature verification.
    #[error("invalid assertion signature")]
    InvalidAssertionSignature,

    /// The client's JWKS could not be fetched from its `jwks_uri`.
    #[error("failed to fetch jwks")]
    JwksFetchFailed(#[source] Box<dyn std::error::Error + Send + Sync + 'static>),
}
234
235impl CredentialsVerificationError {
236 #[must_use]
238 pub fn is_internal(&self) -> bool {
239 matches!(
240 self,
241 Self::DecryptionError | Self::InvalidClientConfig | Self::JwksFetchFailed(_)
242 )
243 }
244}
245
/// Axum extractor for OAuth 2.0 client authentication.
///
/// Combines the `Authorization` header and the form body into
/// [`Credentials`], and deserializes the remaining form fields as `F`.
#[derive(Debug, PartialEq, Eq)]
pub struct ClientAuthorization<F = ()> {
    /// The credentials presented by the client.
    pub credentials: Credentials,
    /// The remaining form fields, or `None` if the request did not have a form
    /// content type.
    pub form: Option<F>,
}
251
252impl<F> ClientAuthorization<F> {
253 #[must_use]
255 pub fn client_id(&self) -> Option<&str> {
256 self.credentials.client_id()
257 }
258}
259
/// Errors returned when extracting [`ClientAuthorization`] from a request.
#[derive(Debug, Error)]
pub enum ClientAuthorizationError {
    /// The `Authorization` header uses a scheme other than `Basic`/`Bearer`,
    /// or its value could not be decoded.
    #[error("Invalid Authorization header")]
    InvalidHeader,

    /// The form body could not be deserialized.
    #[error("Could not deserialize request body")]
    BadForm(#[source] FailedToDeserializeForm),

    /// The `client_id` in the form differs from the one in the credentials.
    #[error("client_id in form ({form:?}) does not match credential ({credential:?})")]
    ClientIdMismatch { credential: String, form: String },

    /// A `client_assertion` was sent with an unsupported `client_assertion_type`.
    #[error("Unsupported client_assertion_type: {client_assertion_type}")]
    UnsupportedClientAssertion { client_assertion_type: String },

    /// Neither an `Authorization` header nor any form credentials were present.
    #[error("No credentials were presented")]
    MissingCredentials,

    /// The combination of credentials presented is not supported.
    #[error("Invalid request")]
    InvalidRequest,

    /// The `client_assertion` could not be parsed as a JWT, or has no string
    /// `sub` claim.
    #[error("Invalid client_assertion")]
    InvalidAssertion,

    /// Used by the extractor for form rejections other than a wrong content
    /// type or a deserialization failure.
    ///
    /// NOTE(review): unlike `CredentialsVerificationError::JwksFetchFailed`,
    /// this boxed error is not bounded by `Send + Sync`.
    #[error(transparent)]
    Internal(Box<dyn std::error::Error>),
}
286
287impl IntoResponse for ClientAuthorizationError {
288 fn into_response(self) -> axum::response::Response {
289 let sentry_event_id = record_error!(self, Self::Internal(_));
290 match &self {
291 ClientAuthorizationError::InvalidHeader => (
292 StatusCode::BAD_REQUEST,
293 sentry_event_id,
294 Json(ClientError::new(
295 ClientErrorCode::InvalidRequest,
296 "Invalid Authorization header",
297 )),
298 ),
299
300 ClientAuthorizationError::BadForm(err) => (
301 StatusCode::BAD_REQUEST,
302 sentry_event_id,
303 Json(
304 ClientError::from(ClientErrorCode::InvalidRequest)
305 .with_description(format!("{err}")),
306 ),
307 ),
308
309 ClientAuthorizationError::ClientIdMismatch { .. } => (
310 StatusCode::BAD_REQUEST,
311 sentry_event_id,
312 Json(
313 ClientError::from(ClientErrorCode::InvalidGrant)
314 .with_description(format!("{self}")),
315 ),
316 ),
317
318 ClientAuthorizationError::UnsupportedClientAssertion { .. } => (
319 StatusCode::BAD_REQUEST,
320 sentry_event_id,
321 Json(
322 ClientError::from(ClientErrorCode::InvalidRequest)
323 .with_description(format!("{self}")),
324 ),
325 ),
326
327 ClientAuthorizationError::MissingCredentials => (
328 StatusCode::BAD_REQUEST,
329 sentry_event_id,
330 Json(ClientError::new(
331 ClientErrorCode::InvalidRequest,
332 "No credentials were presented",
333 )),
334 ),
335
336 ClientAuthorizationError::InvalidRequest => (
337 StatusCode::BAD_REQUEST,
338 sentry_event_id,
339 Json(ClientError::from(ClientErrorCode::InvalidRequest)),
340 ),
341
342 ClientAuthorizationError::InvalidAssertion => (
343 StatusCode::BAD_REQUEST,
344 sentry_event_id,
345 Json(ClientError::new(
346 ClientErrorCode::InvalidRequest,
347 "Invalid client_assertion",
348 )),
349 ),
350
351 ClientAuthorizationError::Internal(e) => (
352 StatusCode::INTERNAL_SERVER_ERROR,
353 sentry_event_id,
354 Json(
355 ClientError::from(ClientErrorCode::ServerError)
356 .with_description(format!("{e}")),
357 ),
358 ),
359 }
360 .into_response()
361 }
362}
363
364impl<S, F> FromRequest<S> for ClientAuthorization<F>
365where
366 F: DeserializeOwned,
367 S: Send + Sync,
368{
369 type Rejection = ClientAuthorizationError;
370
371 async fn from_request(
372 req: Request<axum::body::Body>,
373 state: &S,
374 ) -> Result<Self, Self::Rejection> {
375 enum Authorization {
376 Basic(String, String),
377 Bearer(String),
378 }
379
380 let authorization = if let Some(header) = req.headers().get(http::header::AUTHORIZATION) {
383 let bytes = header.as_bytes();
384 if bytes.len() >= 6 && bytes[..6].eq_ignore_ascii_case(b"Basic ") {
385 let Some(decoded) = Basic::decode(header) else {
386 return Err(ClientAuthorizationError::InvalidHeader);
387 };
388
389 Some(Authorization::Basic(
390 decoded.username().to_owned(),
391 decoded.password().to_owned(),
392 ))
393 } else if bytes.len() >= 7 && bytes[..7].eq_ignore_ascii_case(b"Bearer ") {
394 let Some(decoded) = Bearer::decode(header) else {
395 return Err(ClientAuthorizationError::InvalidHeader);
396 };
397
398 Some(Authorization::Bearer(decoded.token().to_owned()))
399 } else {
400 return Err(ClientAuthorizationError::InvalidHeader);
401 }
402 } else {
403 None
404 };
405
406 let (
408 client_id_from_form,
409 client_secret_from_form,
410 client_assertion_type,
411 client_assertion,
412 form,
413 ) = match Form::<AuthorizedForm<F>>::from_request(req, state).await {
414 Ok(Form(form)) => (
415 form.client_id,
416 form.client_secret,
417 form.client_assertion_type,
418 form.client_assertion,
419 Some(form.inner),
420 ),
421 Err(FormRejection::InvalidFormContentType(_err)) => (None, None, None, None, None),
423 Err(FormRejection::FailedToDeserializeForm(err)) => {
425 return Err(ClientAuthorizationError::BadForm(err));
426 }
427 Err(e) => return Err(ClientAuthorizationError::Internal(Box::new(e))),
429 };
430
431 let credentials = match (
433 authorization,
434 client_id_from_form,
435 client_secret_from_form,
436 client_assertion_type,
437 client_assertion,
438 ) {
439 (
440 Some(Authorization::Basic(client_id, client_secret)),
441 client_id_from_form,
442 None,
443 None,
444 None,
445 ) => {
446 if let Some(client_id_from_form) = client_id_from_form {
447 if client_id != client_id_from_form {
449 return Err(ClientAuthorizationError::ClientIdMismatch {
450 credential: client_id,
451 form: client_id_from_form,
452 });
453 }
454 }
455
456 Credentials::ClientSecretBasic {
457 client_id,
458 client_secret,
459 }
460 }
461
462 (None, Some(client_id), Some(client_secret), None, None) => {
463 Credentials::ClientSecretPost {
465 client_id,
466 client_secret,
467 }
468 }
469
470 (None, Some(client_id), None, None, None) => {
471 Credentials::None { client_id }
473 }
474
475 (
476 None,
477 client_id_from_form,
478 None,
479 Some(client_assertion_type),
480 Some(client_assertion),
481 ) if client_assertion_type == JWT_BEARER_CLIENT_ASSERTION => {
482 let jwt: Jwt<'static, HashMap<String, Value>> = Jwt::try_from(client_assertion)
484 .map_err(|_| ClientAuthorizationError::InvalidAssertion)?;
485
486 let client_id = if let Some(Value::String(client_id)) = jwt.payload().get("sub") {
487 client_id.clone()
488 } else {
489 return Err(ClientAuthorizationError::InvalidAssertion);
490 };
491
492 if let Some(client_id_from_form) = client_id_from_form {
493 if client_id != client_id_from_form {
495 return Err(ClientAuthorizationError::ClientIdMismatch {
496 credential: client_id,
497 form: client_id_from_form,
498 });
499 }
500 }
501
502 Credentials::ClientAssertionJwtBearer {
503 client_id,
504 jwt: Box::new(jwt),
505 }
506 }
507
508 (None, None, None, Some(client_assertion_type), Some(_client_assertion)) => {
509 return Err(ClientAuthorizationError::UnsupportedClientAssertion {
511 client_assertion_type,
512 });
513 }
514
515 (Some(Authorization::Bearer(token)), None, None, None, None) => {
516 Credentials::BearerToken { token }
518 }
519
520 (None, None, None, None, None) => {
521 return Err(ClientAuthorizationError::MissingCredentials);
523 }
524
525 _ => {
526 return Err(ClientAuthorizationError::InvalidRequest);
528 }
529 };
530
531 Ok(ClientAuthorization { credentials, form })
532 }
533}
534
#[cfg(test)]
mod tests {
    use axum::body::Body;
    use http::{Method, Request};

    use super::*;

    /// `Authorization: Basic` value encoding `client-id:client-secret`.
    const BASIC_CLIENT_CREDENTIALS: &str = "Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ=";

    /// Build a form-encoded `POST` request, optionally with an `Authorization`
    /// header.
    fn form_request(authorization: Option<&str>, body: String) -> Request<Body> {
        let mut builder = Request::builder().method(Method::POST).header(
            http::header::CONTENT_TYPE,
            mime::APPLICATION_WWW_FORM_URLENCODED.as_ref(),
        );

        if let Some(value) = authorization {
            builder = builder.header(http::header::AUTHORIZATION, value);
        }

        builder.body(Body::new(body)).unwrap()
    }

    /// Run the extractor, keeping the remaining form fields as JSON values.
    async fn extract(
        req: Request<Body>,
    ) -> Result<ClientAuthorization<serde_json::Value>, ClientAuthorizationError> {
        ClientAuthorization::<serde_json::Value>::from_request(req, &()).await
    }

    #[tokio::test]
    async fn none_test() {
        let req = form_request(None, "client_id=client-id&foo=bar".to_owned());

        assert_eq!(
            extract(req).await.unwrap(),
            ClientAuthorization {
                credentials: Credentials::None {
                    client_id: "client-id".to_owned(),
                },
                form: Some(serde_json::json!({"foo": "bar"})),
            }
        );
    }

    #[tokio::test]
    async fn client_secret_basic_test() {
        let expected = ClientAuthorization {
            credentials: Credentials::ClientSecretBasic {
                client_id: "client-id".to_owned(),
                client_secret: "client-secret".to_owned(),
            },
            form: Some(serde_json::json!({"foo": "bar"})),
        };

        // Credentials only in the header.
        let req = form_request(Some(BASIC_CLIENT_CREDENTIALS), "foo=bar".to_owned());
        assert_eq!(extract(req).await.unwrap(), expected);

        // A matching client_id in the form is accepted.
        let req = form_request(
            Some(BASIC_CLIENT_CREDENTIALS),
            "client_id=client-id&foo=bar".to_owned(),
        );
        assert_eq!(extract(req).await.unwrap(), expected);

        // A different client_id in the form is rejected.
        let req = form_request(
            Some(BASIC_CLIENT_CREDENTIALS),
            "client_id=mismatch-id&foo=bar".to_owned(),
        );
        assert!(matches!(
            extract(req).await,
            Err(ClientAuthorizationError::ClientIdMismatch { .. }),
        ));

        // A Basic header that is not valid base64 is rejected.
        let req = form_request(Some("Basic invalid"), "foo=bar".to_owned());
        assert!(matches!(
            extract(req).await,
            Err(ClientAuthorizationError::InvalidHeader),
        ));
    }

    #[tokio::test]
    async fn client_secret_post_test() {
        let req = form_request(
            None,
            "client_id=client-id&client_secret=client-secret&foo=bar".to_owned(),
        );

        assert_eq!(
            extract(req).await.unwrap(),
            ClientAuthorization {
                credentials: Credentials::ClientSecretPost {
                    client_id: "client-id".to_owned(),
                    client_secret: "client-secret".to_owned(),
                },
                form: Some(serde_json::json!({"foo": "bar"})),
            }
        );
    }

    #[tokio::test]
    async fn client_assertion_test() {
        // HS256 JWT whose `sub` is `client-id`, verified below with the shared
        // secret `client-secret`.
        let jwt = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJjbGllbnQtaWQiLCJzdWIiOiJjbGllbnQtaWQiLCJhdWQiOiJodHRwczovL2V4YW1wbGUuY29tL29hdXRoMi9pbnRyb3NwZWN0IiwianRpIjoiYWFiYmNjIiwiZXhwIjoxNTE2MjM5MzIyLCJpYXQiOjE1MTYyMzkwMjJ9.XTaACG_Rww0GPecSZvkbem-AczNy9LLNBueCLCiQajU";
        let body = format!(
            "client_assertion_type={JWT_BEARER_CLIENT_ASSERTION}&client_assertion={jwt}&foo=bar",
        );

        let authz = extract(form_request(None, body)).await.unwrap();
        assert_eq!(authz.form, Some(serde_json::json!({"foo": "bar"})));

        let Credentials::ClientAssertionJwtBearer { client_id, jwt } = authz.credentials else {
            panic!("expected a JWT client_assertion");
        };

        assert_eq!(client_id, "client-id");
        jwt.verify_with_shared_secret(b"client-secret".to_vec())
            .unwrap();
    }

    #[tokio::test]
    async fn bearer_token_test() {
        let req = form_request(Some("Bearer token"), "foo=bar".to_owned());

        assert_eq!(
            extract(req).await.unwrap(),
            ClientAuthorization {
                credentials: Credentials::BearerToken {
                    token: "token".to_owned(),
                },
                form: Some(serde_json::json!({"foo": "bar"})),
            }
        );
    }
}
739}