1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
// Copyright 2024 New Vector Ltd.
// Copyright 2022-2024 Kévin Commaille.
//
// SPDX-License-Identifier: AGPL-3.0-only
// Please see LICENSE in the repository root for full details.

//! Requests for using [Refresh Tokens].
//!
//! [Refresh Tokens]: https://openid.net/specs/openid-connect-core-1_0.html#RefreshTokens

use chrono::{DateTime, Utc};
use mas_jose::claims::{self, TokenHash};
use oauth2_types::{
    requests::{AccessTokenRequest, AccessTokenResponse, RefreshTokenGrant},
    scope::Scope,
};
use rand::Rng;
use url::Url;

use super::jose::JwtVerificationData;
use crate::{
    error::{IdTokenError, TokenRefreshError},
    requests::{jose::verify_id_token, token::request_access_token},
    types::{client_credentials::ClientCredentials, IdToken},
};

/// Exchange an authorization code for an access token.
///
/// This should be used as the first step for logging in, and to request a
/// token with a new scope.
///
/// # Arguments
///
/// * `http_client` - The reqwest client to use for making HTTP requests.
///
/// * `client_credentials` - The credentials obtained when registering the
///   client.
///
/// * `token_endpoint` - The URL of the issuer's Token endpoint.
///
/// * `refresh_token` - The token used to refresh the access token returned at
///   the Token endpoint.
///
/// * `scope` - The scope of the access token. The requested scope must not
///   include any scope not originally granted to the access token, and if
///   omitted is treated as equal to the scope originally granted by the issuer.
///
/// * `id_token_verification_data` - The data required to verify the ID Token in
///   the response.
///
///   The signing algorithm corresponds to the `id_token_signed_response_alg`
/// field in the client metadata.
///
///   If it is not provided, the ID Token won't be verified.
///
/// * `auth_id_token` - If an ID Token is expected in the response, the ID token
///   that was returned from the latest authorization request.
///
/// * `now` - The current time.
///
/// * `rng` - A random number generator.
///
/// # Errors
///
/// Returns an error if the request fails, the response is invalid or the
/// verification of the ID Token fails.
#[allow(clippy::too_many_arguments)]
#[tracing::instrument(skip_all, fields(token_endpoint))]
pub async fn refresh_access_token(
    http_client: &reqwest::Client,
    client_credentials: ClientCredentials,
    token_endpoint: &Url,
    refresh_token: String,
    scope: Option<Scope>,
    id_token_verification_data: Option<JwtVerificationData<'_>>,
    auth_id_token: Option<&IdToken<'_>>,
    now: DateTime<Utc>,
    rng: &mut impl Rng,
) -> Result<(AccessTokenResponse, Option<IdToken<'static>>), TokenRefreshError> {
    tracing::debug!("Refreshing access token…");

    let token_response = request_access_token(
        http_client,
        client_credentials,
        token_endpoint,
        AccessTokenRequest::RefreshToken(RefreshTokenGrant {
            refresh_token,
            scope,
        }),
        now,
        rng,
    )
    .await?;

    let id_token = if let Some((verification_data, id_token)) =
        id_token_verification_data.zip(token_response.id_token.as_ref())
    {
        let auth_id_token = auth_id_token.ok_or(IdTokenError::MissingAuthIdToken)?;
        let signing_alg = verification_data.signing_algorithm;

        let id_token = verify_id_token(id_token, verification_data, Some(auth_id_token), now)?;

        let mut claims = id_token.payload().clone();

        // Access token hash must match.
        claims::AT_HASH
            .extract_optional_with_options(
                &mut claims,
                TokenHash::new(signing_alg, &token_response.access_token),
            )
            .map_err(IdTokenError::from)?;

        Some(id_token.into_owned())
    } else {
        None
    };

    Ok((token_response, id_token))
}