mas_oidc_client/requests/
authorization_code.rs1use std::{collections::HashSet, num::NonZeroU32};
12
13use base64ct::{Base64UrlUnpadded, Encoding};
14use chrono::{DateTime, Utc};
15use language_tags::LanguageTag;
16use mas_iana::oauth::{OAuthAuthorizationEndpointResponseType, PkceCodeChallengeMethod};
17use mas_jose::claims::{self, TokenHash};
18use oauth2_types::{
19 pkce,
20 prelude::CodeChallengeMethodExt,
21 requests::{
22 AccessTokenRequest, AccessTokenResponse, AuthorizationCodeGrant, AuthorizationRequest,
23 Display, Prompt, ResponseMode,
24 },
25 scope::{OPENID, Scope},
26};
27use rand::{
28 Rng,
29 distributions::{Alphanumeric, DistString},
30};
31use serde::Serialize;
32use url::Url;
33
34use super::jose::JwtVerificationData;
35use crate::{
36 error::{AuthorizationError, IdTokenError, TokenAuthorizationCodeError},
37 requests::{jose::verify_id_token, token::request_access_token},
38 types::{IdToken, client_credentials::ClientCredentials},
39};
40
41#[derive(Debug, Clone)]
43pub struct AuthorizationRequestData {
44 pub client_id: String,
46
47 pub scope: Scope,
52
53 pub redirect_uri: Url,
57
58 pub code_challenge_methods_supported: Option<Vec<PkceCodeChallengeMethod>>,
63
64 pub display: Option<Display>,
67
68 pub prompt: Option<Vec<Prompt>>,
73
74 pub max_age: Option<NonZeroU32>,
77
78 pub ui_locales: Option<Vec<LanguageTag>>,
80
81 pub id_token_hint: Option<String>,
85
86 pub login_hint: Option<String>,
89
90 pub acr_values: Option<HashSet<String>>,
92
93 pub response_mode: Option<ResponseMode>,
95}
96
97impl AuthorizationRequestData {
98 #[must_use]
101 pub fn new(client_id: String, scope: Scope, redirect_uri: Url) -> Self {
102 Self {
103 client_id,
104 scope,
105 redirect_uri,
106 code_challenge_methods_supported: None,
107 display: None,
108 prompt: None,
109 max_age: None,
110 ui_locales: None,
111 id_token_hint: None,
112 login_hint: None,
113 acr_values: None,
114 response_mode: None,
115 }
116 }
117
118 #[must_use]
121 pub fn with_code_challenge_methods_supported(
122 mut self,
123 code_challenge_methods_supported: Vec<PkceCodeChallengeMethod>,
124 ) -> Self {
125 self.code_challenge_methods_supported = Some(code_challenge_methods_supported);
126 self
127 }
128
129 #[must_use]
131 pub fn with_display(mut self, display: Display) -> Self {
132 self.display = Some(display);
133 self
134 }
135
136 #[must_use]
138 pub fn with_prompt(mut self, prompt: Vec<Prompt>) -> Self {
139 self.prompt = Some(prompt);
140 self
141 }
142
143 #[must_use]
145 pub fn with_max_age(mut self, max_age: NonZeroU32) -> Self {
146 self.max_age = Some(max_age);
147 self
148 }
149
150 #[must_use]
152 pub fn with_ui_locales(mut self, ui_locales: Vec<LanguageTag>) -> Self {
153 self.ui_locales = Some(ui_locales);
154 self
155 }
156
157 #[must_use]
159 pub fn with_id_token_hint(mut self, id_token_hint: String) -> Self {
160 self.id_token_hint = Some(id_token_hint);
161 self
162 }
163
164 #[must_use]
166 pub fn with_login_hint(mut self, login_hint: String) -> Self {
167 self.login_hint = Some(login_hint);
168 self
169 }
170
171 #[must_use]
173 pub fn with_acr_values(mut self, acr_values: HashSet<String>) -> Self {
174 self.acr_values = Some(acr_values);
175 self
176 }
177
178 #[must_use]
180 pub fn with_response_mode(mut self, response_mode: ResponseMode) -> Self {
181 self.response_mode = Some(response_mode);
182 self
183 }
184}
185
186#[derive(Debug, Clone, PartialEq, Eq)]
189pub struct AuthorizationValidationData {
190 pub state: String,
192
193 pub nonce: Option<String>,
197
198 pub redirect_uri: Url,
200
201 pub code_challenge_verifier: Option<String>,
203}
204
205#[derive(Clone, Serialize)]
206struct FullAuthorizationRequest {
207 #[serde(flatten)]
208 inner: AuthorizationRequest,
209
210 #[serde(flatten, skip_serializing_if = "Option::is_none")]
211 pkce: Option<pkce::AuthorizationRequest>,
212}
213
214fn build_authorization_request(
216 authorization_data: AuthorizationRequestData,
217 rng: &mut impl Rng,
218) -> Result<(FullAuthorizationRequest, AuthorizationValidationData), AuthorizationError> {
219 let AuthorizationRequestData {
220 client_id,
221 scope,
222 redirect_uri,
223 code_challenge_methods_supported,
224 display,
225 prompt,
226 max_age,
227 ui_locales,
228 id_token_hint,
229 login_hint,
230 acr_values,
231 response_mode,
232 } = authorization_data;
233
234 let is_openid = scope.contains(&OPENID);
235
236 let state = Alphanumeric.sample_string(rng, 16);
238
239 let nonce = is_openid.then(|| Alphanumeric.sample_string(rng, 16));
241
242 let (pkce, code_challenge_verifier) = if code_challenge_methods_supported
244 .iter()
245 .any(|methods| methods.contains(&PkceCodeChallengeMethod::S256))
246 {
247 let mut verifier = [0u8; 32];
248 rng.fill(&mut verifier);
249
250 let method = PkceCodeChallengeMethod::S256;
251 let verifier = Base64UrlUnpadded::encode_string(&verifier);
252 let code_challenge = method.compute_challenge(&verifier)?.into();
253
254 let pkce = pkce::AuthorizationRequest {
255 code_challenge_method: method,
256 code_challenge,
257 };
258
259 (Some(pkce), Some(verifier))
260 } else {
261 (None, None)
262 };
263
264 let auth_request = FullAuthorizationRequest {
265 inner: AuthorizationRequest {
266 response_type: OAuthAuthorizationEndpointResponseType::Code.into(),
267 client_id,
268 redirect_uri: Some(redirect_uri.clone()),
269 scope,
270 state: Some(state.clone()),
271 response_mode,
272 nonce: nonce.clone(),
273 display,
274 prompt,
275 max_age,
276 ui_locales,
277 id_token_hint,
278 login_hint,
279 acr_values,
280 request: None,
281 request_uri: None,
282 registration: None,
283 },
284 pkce,
285 };
286
287 let auth_data = AuthorizationValidationData {
288 state,
289 nonce,
290 redirect_uri,
291 code_challenge_verifier,
292 };
293
294 Ok((auth_request, auth_data))
295}
296
297pub fn build_authorization_url(
328 authorization_endpoint: Url,
329 authorization_data: AuthorizationRequestData,
330 rng: &mut impl Rng,
331) -> Result<(Url, AuthorizationValidationData), AuthorizationError> {
332 tracing::debug!(
333 scope = ?authorization_data.scope,
334 "Authorizing..."
335 );
336
337 let (authorization_request, validation_data) =
338 build_authorization_request(authorization_data, rng)?;
339
340 let authorization_query = serde_urlencoded::to_string(authorization_request)?;
341
342 let mut authorization_url = authorization_endpoint;
343
344 let mut full_query = authorization_url
346 .query()
347 .map(ToOwned::to_owned)
348 .unwrap_or_default();
349 if !full_query.is_empty() {
350 full_query.push('&');
351 }
352 full_query.push_str(&authorization_query);
353
354 authorization_url.set_query(Some(&full_query));
355
356 Ok((authorization_url, validation_data))
357}
358
359#[allow(clippy::too_many_arguments)]
397#[tracing::instrument(skip_all, fields(token_endpoint))]
398pub async fn access_token_with_authorization_code(
399 http_client: &reqwest::Client,
400 client_credentials: ClientCredentials,
401 token_endpoint: &Url,
402 code: String,
403 validation_data: AuthorizationValidationData,
404 id_token_verification_data: Option<JwtVerificationData<'_>>,
405 now: DateTime<Utc>,
406 rng: &mut impl Rng,
407) -> Result<(AccessTokenResponse, Option<IdToken<'static>>), TokenAuthorizationCodeError> {
408 tracing::debug!("Exchanging authorization code for access token...");
409
410 let token_response = request_access_token(
411 http_client,
412 client_credentials,
413 token_endpoint,
414 AccessTokenRequest::AuthorizationCode(AuthorizationCodeGrant {
415 code: code.clone(),
416 redirect_uri: Some(validation_data.redirect_uri),
417 code_verifier: validation_data.code_challenge_verifier,
418 }),
419 now,
420 rng,
421 )
422 .await?;
423
424 let id_token = if let Some(verification_data) = id_token_verification_data {
425 let signing_alg = verification_data.signing_algorithm;
426
427 let id_token = token_response
428 .id_token
429 .as_deref()
430 .ok_or(IdTokenError::MissingIdToken)?;
431
432 let id_token = verify_id_token(id_token, verification_data, None, now)?;
433
434 let mut claims = id_token.payload().clone();
435
436 claims::AT_HASH
438 .extract_optional_with_options(
439 &mut claims,
440 TokenHash::new(signing_alg, &token_response.access_token),
441 )
442 .map_err(IdTokenError::from)?;
443
444 claims::C_HASH
446 .extract_optional_with_options(&mut claims, TokenHash::new(signing_alg, &code))
447 .map_err(IdTokenError::from)?;
448
449 if let Some(nonce) = validation_data.nonce.as_deref() {
451 claims::NONCE
452 .extract_required_with_options(&mut claims, nonce)
453 .map_err(IdTokenError::from)?;
454 }
455
456 Some(id_token.into_owned())
457 } else {
458 None
459 };
460
461 Ok((token_response, id_token))
462}