1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
// Copyright 2024 New Vector Ltd.
// Copyright 2022-2024 Kévin Commaille.
//
// SPDX-License-Identifier: AGPL-3.0-only
// Please see LICENSE in the repository root for full details.

//! The error types used in this crate.

use async_trait::async_trait;
use mas_jose::{
    claims::ClaimError,
    jwa::InvalidAlgorithm,
    jwt::{JwtDecodeError, JwtSignatureError, NoKeyWorked},
};
use oauth2_types::{oidc::ProviderMetadataVerificationError, pkce::CodeChallengeError};
use serde::Deserialize;
use thiserror::Error;

/// All possible errors when using this crate.
#[derive(Debug, Error)]
#[error(transparent)]
pub enum Error {
    /// An error occurred fetching provider metadata.
    Discovery(#[from] DiscoveryError),

    /// An error occurred fetching the provider JWKS.
    Jwks(#[from] JwksError),

    /// An error occurred building the authorization URL.
    Authorization(#[from] AuthorizationError),

    /// An error occurred exchanging an authorization code for an access token.
    TokenAuthorizationCode(#[from] TokenAuthorizationCodeError),

    /// An error occurred requesting an access token with client credentials.
    TokenClientCredentials(#[from] TokenRequestError),

    /// An error occurred refreshing an access token.
    TokenRefresh(#[from] TokenRefreshError),

    /// An error occurred requesting user info.
    UserInfo(#[from] UserInfoError),
}

/// All possible errors when fetching provider metadata.
#[derive(Debug, Error)]
#[error("Fetching provider metadata failed")]
pub enum DiscoveryError {
    /// An error occurred building the request's URL.
    IntoUrl(#[from] url::ParseError),

    /// The server returned an HTTP error status code.
    Http(#[from] reqwest::Error),

    /// An error occurred validating the metadata.
    Validation(#[from] ProviderMetadataVerificationError),

    /// Discovery is disabled for this provider.
    #[error("Discovery is disabled for this provider")]
    Disabled,
}

/// All possible errors when authorizing the client.
#[derive(Debug, Error)]
#[error("Building the authorization URL failed")]
pub enum AuthorizationError {
    /// An error occurred constructing the PKCE code challenge.
    Pkce(#[from] CodeChallengeError),

    /// An error occurred serializing the request.
    UrlEncoded(#[from] serde_urlencoded::ser::Error),
}

/// All possible errors when requesting an access token.
#[derive(Debug, Error)]
#[error("Request to the token endpoint failed")]
pub enum TokenRequestError {
    /// The HTTP client returned an error.
    Http(#[from] reqwest::Error),

    /// The server returned an error
    OAuth2(#[from] OAuth2Error),

    /// Error while injecting the client credentials into the request.
    Credentials(#[from] CredentialsError),
}

/// All possible errors when exchanging a code for an access token.
#[derive(Debug, Error)]
pub enum TokenAuthorizationCodeError {
    /// An error occurred requesting the access token.
    #[error(transparent)]
    Token(#[from] TokenRequestError),

    /// An error occurred validating the ID Token.
    #[error("Verifying the 'id_token' returned by the provider failed")]
    IdToken(#[from] IdTokenError),
}

/// All possible errors when refreshing an access token.
#[derive(Debug, Error)]
pub enum TokenRefreshError {
    /// An error occurred requesting the access token.
    #[error(transparent)]
    Token(#[from] TokenRequestError),

    /// An error occurred validating the ID Token.
    #[error("Verifying the 'id_token' returned by the provider failed")]
    IdToken(#[from] IdTokenError),
}

/// All possible errors when requesting user info.
#[derive(Debug, Error)]
pub enum UserInfoError {
    /// The content-type header is missing from the response.
    #[error("missing response content-type")]
    MissingResponseContentType,

    /// The content-type is not valid.
    #[error("invalid response content-type")]
    InvalidResponseContentTypeValue,

    /// The content-type is not the one that was expected.
    #[error("unexpected response content-type {got:?}, expected {expected:?}")]
    UnexpectedResponseContentType {
        /// The expected content-type.
        expected: String,
        /// The returned content-type.
        got: String,
    },

    /// An error occurred verifying the Id Token.
    #[error("Verifying the 'id_token' returned by the provider failed")]
    IdToken(#[from] IdTokenError),

    /// An error occurred sending the request.
    #[error(transparent)]
    Http(#[from] reqwest::Error),

    /// The server returned an error
    #[error(transparent)]
    OAuth2(#[from] OAuth2Error),
}

/// All possible errors when requesting a JWKS.
#[derive(Debug, Error)]
#[error("Failed to fetch JWKS")]
pub enum JwksError {
    /// An error occurred sending the request.
    Http(#[from] reqwest::Error),
}

/// All possible errors when verifying a JWT.
#[derive(Debug, Error)]
pub enum JwtVerificationError {
    /// An error occured decoding the JWT.
    #[error(transparent)]
    JwtDecode(#[from] JwtDecodeError),

    /// No key worked for verifying the JWT's signature.
    #[error(transparent)]
    JwtSignature(#[from] NoKeyWorked),

    /// An error occurred extracting a claim.
    #[error(transparent)]
    Claim(#[from] ClaimError),

    /// The algorithm used for signing the JWT is not the one that was
    /// requested.
    #[error("wrong signature alg")]
    WrongSignatureAlg,
}

/// All possible errors when verifying an ID token.
#[derive(Debug, Error)]
pub enum IdTokenError {
    /// No ID Token was found in the response although one was expected.
    #[error("ID token is missing")]
    MissingIdToken,

    /// The ID Token from the latest Authorization was not provided although
    /// this request expects to be verified against one.
    #[error("Authorization ID token is missing")]
    MissingAuthIdToken,

    #[error(transparent)]
    /// An error occurred validating the ID Token's signature and basic claims.
    Jwt(#[from] JwtVerificationError),

    #[error(transparent)]
    /// An error occurred extracting a claim.
    Claim(#[from] ClaimError),

    /// The subject identifier returned by the issuer is not the same as the one
    /// we got before.
    #[error("wrong subject identifier")]
    WrongSubjectIdentifier,

    /// The authentication time returned by the issuer is not the same as the
    /// one we got before.
    #[error("wrong authentication time")]
    WrongAuthTime,
}

/// All errors that can occur when adding client credentials to the request.
#[derive(Debug, Error)]
pub enum CredentialsError {
    /// Trying to use an unsupported authentication method.
    #[error("unsupported authentication method")]
    UnsupportedMethod,

    /// When authenticationg with `private_key_jwt`, no private key was found
    /// for the given algorithm.
    #[error("no private key was found for the given algorithm")]
    NoPrivateKeyFound,

    /// The signing algorithm is invalid for this authentication method.
    #[error("invalid algorithm: {0}")]
    InvalidSigningAlgorithm(#[from] InvalidAlgorithm),

    /// An error occurred when building the claims of the JWT.
    #[error(transparent)]
    JwtClaims(#[from] ClaimError),

    /// The key found cannot be used with the algorithm.
    #[error("Wrong algorithm for key")]
    JwtWrongAlgorithm,

    /// An error occurred when signing the JWT.
    #[error(transparent)]
    JwtSignature(#[from] JwtSignatureError),
}

#[derive(Debug, Deserialize)]
struct OAuth2ErrorResponse {
    error: String,
    error_description: Option<String>,
    error_uri: Option<String>,
}

impl std::fmt::Display for OAuth2ErrorResponse {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{:?}", self.error)?;

        if let Some(error_uri) = &self.error_uri {
            write!(f, " (See {error_uri})")?;
        }

        if let Some(error_description) = &self.error_description {
            write!(f, ": {error_description}")?;
        }

        Ok(())
    }
}

/// An error returned by the OAuth 2.0 provider
#[derive(Debug, Error)]
pub struct OAuth2Error {
    error: Option<OAuth2ErrorResponse>,

    #[source]
    inner: reqwest::Error,
}

impl std::fmt::Display for OAuth2Error {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        if let Some(error) = &self.error {
            write!(
                f,
                "Request to the provider failed with the following error: {error}"
            )
        } else {
            write!(f, "Request to the provider failed")
        }
    }
}

impl From<reqwest::Error> for OAuth2Error {
    fn from(inner: reqwest::Error) -> Self {
        Self { error: None, inner }
    }
}

/// An extension trait to deal with error responses from the OAuth 2.0 provider
#[async_trait]
pub(crate) trait ResponseExt {
    async fn error_from_oauth2_error_response(self) -> Result<Self, OAuth2Error>
    where
        Self: Sized;
}

#[async_trait]
impl ResponseExt for reqwest::Response {
    async fn error_from_oauth2_error_response(self) -> Result<Self, OAuth2Error> {
        let Err(inner) = self.error_for_status_ref() else {
            return Ok(self);
        };

        let error: OAuth2ErrorResponse = self.json().await?;

        Err(OAuth2Error {
            error: Some(error),
            inner,
        })
    }
}