1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
// Copyright 2024 New Vector Ltd.
// Copyright 2021-2024 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only
// Please see LICENSE in the repository root for full details.

use std::ops::Deref;

use figment::Figment;
use mas_iana::oauth::OAuthClientAuthenticationMethod;
use mas_jose::jwk::PublicJsonWebKeySet;
use schemars::JsonSchema;
use serde::{de::Error, Deserialize, Serialize};
use ulid::Ulid;
use url::Url;

use super::ConfigurationSection;

/// Source of a client's JSON Web Key Set: either the key set itself, given
/// inline, or a URL referencing it.
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
pub enum JwksOrJwksUri {
    /// The key set, provided inline
    #[serde(rename = "jwks")]
    Jwks(PublicJsonWebKeySet),

    /// The URL of the key set
    #[serde(rename = "jwks_uri")]
    JwksUri(Url),
}

impl From<PublicJsonWebKeySet> for JwksOrJwksUri {
    fn from(jwks: PublicJsonWebKeySet) -> Self {
        Self::Jwks(jwks)
    }
}

/// Authentication method used by clients
///
/// Each variant is (de)serialized as its `snake_case` name, which is also the
/// OAuth 2.0 `token_endpoint_auth_method` value it stands for.
#[derive(JsonSchema, Serialize, Deserialize, Copy, Clone, Debug)]
#[serde(rename_all = "snake_case")]
pub enum ClientAuthMethodConfig {
    /// `none`: No authentication
    None,

    /// `client_secret_basic`: `client_id` and `client_secret` used as basic
    /// authorization credentials
    ClientSecretBasic,

    /// `client_secret_post`: `client_id` and `client_secret` sent in the
    /// request body
    ClientSecretPost,

    /// `client_secret_jwt`: a `client_assertion` sent in the request body and
    /// signed using the `client_secret`
    ClientSecretJwt,

    /// `private_key_jwt`: a `client_assertion` sent in the request body and
    /// signed by an asymmetric key
    PrivateKeyJwt,
}

impl std::fmt::Display for ClientAuthMethodConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            ClientAuthMethodConfig::None => write!(f, "none"),
            ClientAuthMethodConfig::ClientSecretBasic => write!(f, "client_secret_basic"),
            ClientAuthMethodConfig::ClientSecretPost => write!(f, "client_secret_post"),
            ClientAuthMethodConfig::ClientSecretJwt => write!(f, "client_secret_jwt"),
            ClientAuthMethodConfig::PrivateKeyJwt => write!(f, "private_key_jwt"),
        }
    }
}

/// An OAuth 2.0 client configuration
///
/// Which credential fields (`client_secret`, `jwks`, `jwks_uri`) may or must
/// be set depends on `client_auth_method`; the combination is checked when
/// the `clients` section is validated.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct ClientConfig {
    /// The client ID
    #[schemars(
        with = "String",
        regex(pattern = r"^[0123456789ABCDEFGHJKMNPQRSTVWXYZ]{26}$"),
        description = "A ULID as per https://github.com/ulid/spec"
    )]
    pub client_id: Ulid,

    // Not `pub`: read it through `ClientConfig::client_auth_method()`, which
    // converts it to an `OAuthClientAuthenticationMethod`.
    /// Authentication method used for this client
    client_auth_method: ClientAuthMethodConfig,

    /// The client secret, used by the `client_secret_basic`,
    /// `client_secret_post` and `client_secret_jwt` authentication methods.
    /// Required by those methods and rejected by `none` and `private_key_jwt`
    #[serde(skip_serializing_if = "Option::is_none")]
    pub client_secret: Option<String>,

    /// The JSON Web Key Set (JWKS) used by the `private_key_jwt` authentication
    /// method. Mutually exclusive with `jwks_uri`; rejected by other methods
    #[serde(skip_serializing_if = "Option::is_none")]
    pub jwks: Option<PublicJsonWebKeySet>,

    /// The URL of the JSON Web Key Set (JWKS) used by the `private_key_jwt`
    /// authentication method. Mutually exclusive with `jwks`; rejected by other
    /// methods
    #[serde(skip_serializing_if = "Option::is_none")]
    pub jwks_uri: Option<Url>,

    /// List of allowed redirect URIs. Defaults to an empty list when omitted
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub redirect_uris: Vec<Url>,
}

impl ClientConfig {
    fn validate(&self) -> Result<(), figment::error::Error> {
        let auth_method = self.client_auth_method;
        match self.client_auth_method {
            ClientAuthMethodConfig::PrivateKeyJwt => {
                if self.jwks.is_none() && self.jwks_uri.is_none() {
                    let error = figment::error::Error::custom(
                        "jwks or jwks_uri is required for private_key_jwt",
                    );
                    return Err(error.with_path("client_auth_method"));
                }

                if self.jwks.is_some() && self.jwks_uri.is_some() {
                    let error =
                        figment::error::Error::custom("jwks and jwks_uri are mutually exclusive");
                    return Err(error.with_path("jwks"));
                }

                if self.client_secret.is_some() {
                    let error = figment::error::Error::custom(
                        "client_secret is not allowed with private_key_jwt",
                    );
                    return Err(error.with_path("client_secret"));
                }
            }

            ClientAuthMethodConfig::ClientSecretPost
            | ClientAuthMethodConfig::ClientSecretBasic
            | ClientAuthMethodConfig::ClientSecretJwt => {
                if self.client_secret.is_none() {
                    let error = figment::error::Error::custom(format!(
                        "client_secret is required for {auth_method}"
                    ));
                    return Err(error.with_path("client_auth_method"));
                }

                if self.jwks.is_some() {
                    let error = figment::error::Error::custom(format!(
                        "jwks is not allowed with {auth_method}"
                    ));
                    return Err(error.with_path("jwks"));
                }

                if self.jwks_uri.is_some() {
                    let error = figment::error::Error::custom(format!(
                        "jwks_uri is not allowed with {auth_method}"
                    ));
                    return Err(error.with_path("jwks_uri"));
                }
            }

            ClientAuthMethodConfig::None => {
                if self.client_secret.is_some() {
                    let error = figment::error::Error::custom(
                        "client_secret is not allowed with none authentication method",
                    );
                    return Err(error.with_path("client_secret"));
                }

                if self.jwks.is_some() {
                    let error = figment::error::Error::custom(
                        "jwks is not allowed with none authentication method",
                    );
                    return Err(error);
                }

                if self.jwks_uri.is_some() {
                    let error = figment::error::Error::custom(
                        "jwks_uri is not allowed with none authentication method",
                    );
                    return Err(error);
                }
            }
        }

        Ok(())
    }

    /// Authentication method used for this client
    #[must_use]
    pub fn client_auth_method(&self) -> OAuthClientAuthenticationMethod {
        match self.client_auth_method {
            ClientAuthMethodConfig::None => OAuthClientAuthenticationMethod::None,
            ClientAuthMethodConfig::ClientSecretBasic => {
                OAuthClientAuthenticationMethod::ClientSecretBasic
            }
            ClientAuthMethodConfig::ClientSecretPost => {
                OAuthClientAuthenticationMethod::ClientSecretPost
            }
            ClientAuthMethodConfig::ClientSecretJwt => {
                OAuthClientAuthenticationMethod::ClientSecretJwt
            }
            ClientAuthMethodConfig::PrivateKeyJwt => OAuthClientAuthenticationMethod::PrivateKeyJwt,
        }
    }
}

/// List of OAuth 2.0/OIDC clients config
///
/// Because the wrapper is `serde(transparent)`, it is (de)serialized as a bare
/// list of [`ClientConfig`] entries, read from the `clients` configuration key.
#[derive(Debug, Clone, Default, Serialize, Deserialize, JsonSchema)]
#[serde(transparent)]
pub struct ClientsConfig(#[schemars(with = "Vec::<ClientConfig>")] Vec<ClientConfig>);

impl ClientsConfig {
    /// Returns true if all fields are at their default values
    pub(crate) fn is_default(&self) -> bool {
        self.0.is_empty()
    }
}

impl Deref for ClientsConfig {
    type Target = Vec<ClientConfig>;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl IntoIterator for ClientsConfig {
    type Item = ClientConfig;
    type IntoIter = std::vec::IntoIter<ClientConfig>;

    fn into_iter(self) -> Self::IntoIter {
        self.0.into_iter()
    }
}

impl ConfigurationSection for ClientsConfig {
    const PATH: Option<&'static str> = Some("clients");

    /// Validates every configured client and stops at the first failure.
    ///
    /// The returned error is enriched with the metadata of the `clients`
    /// section, the default profile, and a path prefixed with
    /// `clients.<index>`, so it points at the offending list entry.
    fn validate(&self, figment: &Figment) -> Result<(), figment::error::Error> {
        // `PATH` is a constant `Some`, so this unwrap cannot fail.
        let section = Self::PATH.unwrap();

        self.0.iter().enumerate().try_for_each(|(index, client)| {
            client.validate().map_err(|mut err| {
                err.metadata = figment.find_metadata(section).cloned();
                err.profile = Some(figment::Profile::Default);
                err.path.insert(0, section.to_owned());
                err.path.insert(1, index.to_string());
                err
            })
        })
    }
}

#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use figment::{
        providers::{Format, Yaml},
        Figment, Jail,
    };

    use super::*;

    /// Loads one client per authentication method, checks that the section
    /// validates, and checks the per-client fields that were parsed.
    #[test]
    fn load_config() {
        Jail::expect_with(|jail| {
            jail.create_file(
                "config.yaml",
                r#"
                  clients:
                    - client_id: 01GFWR28C4KNE04WG3HKXB7C9R
                      client_auth_method: none
                      redirect_uris:
                        - https://exemple.fr/callback

                    - client_id: 01GFWR32NCQ12B8Z0J8CPXRRB6
                      client_auth_method: client_secret_basic
                      client_secret: hello

                    - client_id: 01GFWR3WHR93Y5HK389H28VHZ9
                      client_auth_method: client_secret_post
                      client_secret: hello

                    - client_id: 01GFWR43R2ZZ8HX9CVBNW9TJWG
                      client_auth_method: client_secret_jwt
                      client_secret: hello

                    - client_id: 01GFWR4BNFDCC4QDG6AMSP1VRR
                      client_auth_method: private_key_jwt
                      jwks:
                        keys:
                        - kid: "03e84aed4ef4431014e8617567864c4efaaaede9"
                          kty: "RSA"
                          alg: "RS256"
                          use: "sig"
                          e: "AQAB"
                          n: "ma2uRyBeSEOatGuDpCiV9oIxlDWix_KypDYuhQfEzqi_BiF4fV266OWfyjcABbam59aJMNvOnKW3u_eZM-PhMCBij5MZ-vcBJ4GfxDJeKSn-GP_dJ09rpDcILh8HaWAnPmMoi4DC0nrfE241wPISvZaaZnGHkOrfN_EnA5DligLgVUbrA5rJhQ1aSEQO_gf1raEOW3DZ_ACU3qhtgO0ZBG3a5h7BPiRs2sXqb2UCmBBgwyvYLDebnpE7AotF6_xBIlR-Cykdap3GHVMXhrIpvU195HF30ZoBU4dMd-AeG6HgRt4Cqy1moGoDgMQfbmQ48Hlunv9_Vi2e2CLvYECcBw"

                        - kid: "d01c1abe249269f72ef7ca2613a86c9f05e59567"
                          kty: "RSA"
                          alg: "RS256"
                          use: "sig"
                          e: "AQAB"
                          n: "0hukqytPwrj1RbMYhYoepCi3CN5k7DwYkTe_Cmb7cP9_qv4ok78KdvFXt5AnQxCRwBD7-qTNkkfMWO2RxUMBdQD0ED6tsSb1n5dp0XY8dSWiBDCX8f6Hr-KolOpvMLZKRy01HdAWcM6RoL9ikbjYHUEW1C8IJnw3MzVHkpKFDL354aptdNLaAdTCBvKzU9WpXo10g-5ctzSlWWjQuecLMQ4G1mNdsR1LHhUENEnOvgT8cDkX0fJzLbEbyBYkdMgKggyVPEB1bg6evG4fTKawgnf0IDSPxIU-wdS9wdSP9ZCJJPLi5CEp-6t6rE_sb2dGcnzjCGlembC57VwpkUvyMw"
                "#,
            )?;

            let figment = Figment::new().merge(Yaml::file("config.yaml"));
            let config = figment.extract_inner::<ClientsConfig>("clients")?;

            // Every sample client is well-formed, so section validation must pass.
            config.validate(&figment)?;

            assert_eq!(config.0.len(), 5);

            assert_eq!(
                config.0[0].client_id,
                Ulid::from_str("01GFWR28C4KNE04WG3HKXB7C9R").unwrap()
            );
            assert_eq!(
                config.0[0].redirect_uris,
                vec!["https://exemple.fr/callback".parse().unwrap()]
            );

            assert_eq!(
                config.0[1].client_id,
                Ulid::from_str("01GFWR32NCQ12B8Z0J8CPXRRB6").unwrap()
            );
            assert_eq!(config.0[1].redirect_uris, Vec::new());

            // Each configured method maps to its OAuth counterpart.
            assert!(matches!(
                config.0[0].client_auth_method(),
                OAuthClientAuthenticationMethod::None
            ));
            assert!(matches!(
                config.0[1].client_auth_method(),
                OAuthClientAuthenticationMethod::ClientSecretBasic
            ));
            assert!(matches!(
                config.0[2].client_auth_method(),
                OAuthClientAuthenticationMethod::ClientSecretPost
            ));
            assert!(matches!(
                config.0[3].client_auth_method(),
                OAuthClientAuthenticationMethod::ClientSecretJwt
            ));
            assert!(matches!(
                config.0[4].client_auth_method(),
                OAuthClientAuthenticationMethod::PrivateKeyJwt
            ));

            // Credentials are parsed into the expected fields.
            assert!(config.0[0].client_secret.is_none());
            for client in &config.0[1..4] {
                assert_eq!(client.client_secret.as_deref(), Some("hello"));
                assert!(client.jwks.is_none());
                assert!(client.jwks_uri.is_none());
            }
            assert!(config.0[4].client_secret.is_none());
            assert!(config.0[4].jwks.is_some());
            assert!(config.0[4].jwks_uri.is_none());

            Ok(())
        });
    }
}