mas_config/sections/
clients.rs

1// Copyright 2024, 2025 New Vector Ltd.
2// Copyright 2021-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
5// Please see LICENSE files in the repository root for full details.
6
7use std::ops::Deref;
8
9use mas_iana::oauth::OAuthClientAuthenticationMethod;
10use mas_jose::jwk::PublicJsonWebKeySet;
11use schemars::JsonSchema;
12use serde::{Deserialize, Serialize, de::Error};
13use serde_with::serde_as;
14use ulid::Ulid;
15use url::Url;
16
17use super::{ClientSecret, ClientSecretRaw, ConfigurationSection};
18
/// Authentication method used by clients
// `rename_all = "snake_case"` makes every variant (de)serialize under the
// snake_case name quoted at the start of its documentation.
#[derive(JsonSchema, Serialize, Deserialize, Copy, Clone, Debug)]
#[serde(rename_all = "snake_case")]
pub enum ClientAuthMethodConfig {
    /// `none`: No authentication
    None,

    /// `client_secret_basic`: `client_id` and `client_secret` used as basic
    /// authorization credentials
    ClientSecretBasic,

    /// `client_secret_post`: `client_id` and `client_secret` sent in the
    /// request body
    ClientSecretPost,

    /// `client_secret_jwt`: a `client_assertion` sent in the request body and
    /// signed using the `client_secret`
    ClientSecretJwt,

    /// `private_key_jwt`: a `client_assertion` sent in the request body and
    /// signed by an asymmetric key
    PrivateKeyJwt,
}
42
43impl std::fmt::Display for ClientAuthMethodConfig {
44    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
45        match self {
46            ClientAuthMethodConfig::None => write!(f, "none"),
47            ClientAuthMethodConfig::ClientSecretBasic => write!(f, "client_secret_basic"),
48            ClientAuthMethodConfig::ClientSecretPost => write!(f, "client_secret_post"),
49            ClientAuthMethodConfig::ClientSecretJwt => write!(f, "client_secret_jwt"),
50            ClientAuthMethodConfig::PrivateKeyJwt => write!(f, "private_key_jwt"),
51        }
52    }
53}
54
/// An OAuth 2.0 client configuration
// Which of `client_secret`, `jwks` and `jwks_uri` may be set depends on
// `client_auth_method`. That cross-field rule is not encoded in the types; it
// is checked by `ClientConfig::validate`.
#[serde_as]
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct ClientConfig {
    /// The client ID
    // The JSON schema is written by hand here: the value is described as a
    // string of exactly 26 characters drawn from the alphabet in the pattern.
    #[schemars(
        with = "String",
        regex(pattern = r"^[0123456789ABCDEFGHJKMNPQRSTVWXYZ]{26}$"),
        description = "A ULID as per https://github.com/ulid/spec"
    )]
    pub client_id: Ulid,

    /// Authentication method used for this client
    // Private field: callers read it through `client_auth_method()`, which
    // returns the matching `OAuthClientAuthenticationMethod`.
    client_auth_method: ClientAuthMethodConfig,

    /// Name of the `OAuth2` client
    // Left out of the serialized output when unset.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub client_name: Option<String>,

    /// The client secret, used by the `client_secret_basic`,
    /// `client_secret_post` and `client_secret_jwt` authentication methods
    // `flatten` inlines the keys of `ClientSecretRaw` into the client entry
    // itself (the test config in this file sets `client_secret` or
    // `client_secret_file` next to `client_id`). The raw form is converted
    // to and from `ClientSecret` through `serde_with::TryFromInto`.
    #[schemars(with = "ClientSecretRaw")]
    #[serde_as(as = "serde_with::TryFromInto<ClientSecretRaw>")]
    #[serde(flatten)]
    pub client_secret: Option<ClientSecret>,

    /// The JSON Web Key Set (JWKS) used by the `private_key_jwt` authentication
    /// method. Mutually exclusive with `jwks_uri`
    // Left out of the serialized output when unset.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub jwks: Option<PublicJsonWebKeySet>,

    /// The URL of the JSON Web Key Set (JWKS) used by the `private_key_jwt`
    /// authentication method. Mutually exclusive with `jwks`
    // Left out of the serialized output when unset.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub jwks_uri: Option<Url>,

    /// List of allowed redirect URIs
    // Optional in the input (defaults to an empty list) and left out of the
    // serialized output when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub redirect_uris: Vec<Url>,
}
95
96impl ClientConfig {
97    fn validate(&self) -> Result<(), Box<figment::error::Error>> {
98        let auth_method = self.client_auth_method;
99        match self.client_auth_method {
100            ClientAuthMethodConfig::PrivateKeyJwt => {
101                if self.jwks.is_none() && self.jwks_uri.is_none() {
102                    let error = figment::error::Error::custom(
103                        "jwks or jwks_uri is required for private_key_jwt",
104                    );
105                    return Err(Box::new(error.with_path("client_auth_method")));
106                }
107
108                if self.jwks.is_some() && self.jwks_uri.is_some() {
109                    let error =
110                        figment::error::Error::custom("jwks and jwks_uri are mutually exclusive");
111                    return Err(Box::new(error.with_path("jwks")));
112                }
113
114                if self.client_secret.is_some() {
115                    let error = figment::error::Error::custom(
116                        "client_secret is not allowed with private_key_jwt",
117                    );
118                    return Err(Box::new(error.with_path("client_secret")));
119                }
120            }
121
122            ClientAuthMethodConfig::ClientSecretPost
123            | ClientAuthMethodConfig::ClientSecretBasic
124            | ClientAuthMethodConfig::ClientSecretJwt => {
125                if self.client_secret.is_none() {
126                    let error = figment::error::Error::custom(format!(
127                        "client_secret is required for {auth_method}"
128                    ));
129                    return Err(Box::new(error.with_path("client_auth_method")));
130                }
131
132                if self.jwks.is_some() {
133                    let error = figment::error::Error::custom(format!(
134                        "jwks is not allowed with {auth_method}"
135                    ));
136                    return Err(Box::new(error.with_path("jwks")));
137                }
138
139                if self.jwks_uri.is_some() {
140                    let error = figment::error::Error::custom(format!(
141                        "jwks_uri is not allowed with {auth_method}"
142                    ));
143                    return Err(Box::new(error.with_path("jwks_uri")));
144                }
145            }
146
147            ClientAuthMethodConfig::None => {
148                if self.client_secret.is_some() {
149                    let error = figment::error::Error::custom(
150                        "client_secret is not allowed with none authentication method",
151                    );
152                    return Err(Box::new(error.with_path("client_secret")));
153                }
154
155                if self.jwks.is_some() {
156                    let error = figment::error::Error::custom(
157                        "jwks is not allowed with none authentication method",
158                    );
159                    return Err(Box::new(error));
160                }
161
162                if self.jwks_uri.is_some() {
163                    let error = figment::error::Error::custom(
164                        "jwks_uri is not allowed with none authentication method",
165                    );
166                    return Err(Box::new(error));
167                }
168            }
169        }
170
171        Ok(())
172    }
173
174    /// Authentication method used for this client
175    #[must_use]
176    pub fn client_auth_method(&self) -> OAuthClientAuthenticationMethod {
177        match self.client_auth_method {
178            ClientAuthMethodConfig::None => OAuthClientAuthenticationMethod::None,
179            ClientAuthMethodConfig::ClientSecretBasic => {
180                OAuthClientAuthenticationMethod::ClientSecretBasic
181            }
182            ClientAuthMethodConfig::ClientSecretPost => {
183                OAuthClientAuthenticationMethod::ClientSecretPost
184            }
185            ClientAuthMethodConfig::ClientSecretJwt => {
186                OAuthClientAuthenticationMethod::ClientSecretJwt
187            }
188            ClientAuthMethodConfig::PrivateKeyJwt => OAuthClientAuthenticationMethod::PrivateKeyJwt,
189        }
190    }
191
192    /// Returns the client secret.
193    ///
194    /// If `client_secret_file` was given, the secret is read from that file.
195    ///
196    /// # Errors
197    ///
198    /// Returns an error when the client secret could not be read from file.
199    pub async fn client_secret(&self) -> anyhow::Result<Option<String>> {
200        Ok(match &self.client_secret {
201            Some(client_secret) => Some(client_secret.value().await?),
202            None => None,
203        })
204    }
205}
206
/// List of OAuth 2.0/OIDC clients config
// `#[serde(transparent)]` makes this newtype (de)serialize exactly like the
// inner `Vec<ClientConfig>`, i.e. as a plain sequence of client entries; the
// schemars override describes it the same way in the JSON schema.
#[derive(Debug, Clone, Default, Serialize, Deserialize, JsonSchema)]
#[serde(transparent)]
pub struct ClientsConfig(#[schemars(with = "Vec::<ClientConfig>")] Vec<ClientConfig>);
211
212impl ClientsConfig {
213    /// Returns true if all fields are at their default values
214    pub(crate) fn is_default(&self) -> bool {
215        self.0.is_empty()
216    }
217}
218
219impl Deref for ClientsConfig {
220    type Target = Vec<ClientConfig>;
221
222    fn deref(&self) -> &Self::Target {
223        &self.0
224    }
225}
226
227impl IntoIterator for ClientsConfig {
228    type Item = ClientConfig;
229    type IntoIter = std::vec::IntoIter<ClientConfig>;
230
231    fn into_iter(self) -> Self::IntoIter {
232        self.0.into_iter()
233    }
234}
235
236impl ConfigurationSection for ClientsConfig {
237    const PATH: Option<&'static str> = Some("clients");
238
239    fn validate(
240        &self,
241        figment: &figment::Figment,
242    ) -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> {
243        for (index, client) in self.0.iter().enumerate() {
244            client.validate().map_err(|mut err| {
245                // Save the error location information in the error
246                err.metadata = figment.find_metadata(Self::PATH.unwrap()).cloned();
247                err.profile = Some(figment::Profile::Default);
248                err.path.insert(0, Self::PATH.unwrap().to_owned());
249                err.path.insert(1, format!("{index}"));
250                err
251            })?;
252        }
253
254        Ok(())
255    }
256}
257
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use figment::{
        Figment, Jail,
        providers::{Format, Yaml},
    };
    use tokio::{runtime::Handle, task};

    use super::*;

    /// Loads a YAML file declaring five clients, one per authentication
    /// method, and checks the parsed client IDs, redirect URIs and client
    /// secrets, including secrets that are read from a file.
    #[tokio::test]
    async fn load_config() {
        // Run on a blocking thread: the closure below is synchronous and uses
        // `Handle::current().block_on(..)` to drive the async
        // `client_secret()` accessor.
        task::spawn_blocking(|| {
            Jail::expect_with(|jail| {
                // Fixture: `none`, `client_secret_basic` (secret from file),
                // `client_secret_post` (inline secret), `client_secret_jwt`
                // (secret from file) and `private_key_jwt` (inline JWKS).
                jail.create_file(
                    "config.yaml",
                    r#"
                      clients:
                        - client_id: 01GFWR28C4KNE04WG3HKXB7C9R
                          client_auth_method: none
                          redirect_uris:
                            - https://exemple.fr/callback

                        - client_id: 01GFWR32NCQ12B8Z0J8CPXRRB6
                          client_auth_method: client_secret_basic
                          client_secret_file: secret

                        - client_id: 01GFWR3WHR93Y5HK389H28VHZ9
                          client_auth_method: client_secret_post
                          client_secret: c1!3n753c237

                        - client_id: 01GFWR43R2ZZ8HX9CVBNW9TJWG
                          client_auth_method: client_secret_jwt
                          client_secret_file: secret

                        - client_id: 01GFWR4BNFDCC4QDG6AMSP1VRR
                          client_auth_method: private_key_jwt
                          jwks:
                            keys:
                            - kid: "03e84aed4ef4431014e8617567864c4efaaaede9"
                              kty: "RSA"
                              alg: "RS256"
                              use: "sig"
                              e: "AQAB"
                              n: "ma2uRyBeSEOatGuDpCiV9oIxlDWix_KypDYuhQfEzqi_BiF4fV266OWfyjcABbam59aJMNvOnKW3u_eZM-PhMCBij5MZ-vcBJ4GfxDJeKSn-GP_dJ09rpDcILh8HaWAnPmMoi4DC0nrfE241wPISvZaaZnGHkOrfN_EnA5DligLgVUbrA5rJhQ1aSEQO_gf1raEOW3DZ_ACU3qhtgO0ZBG3a5h7BPiRs2sXqb2UCmBBgwyvYLDebnpE7AotF6_xBIlR-Cykdap3GHVMXhrIpvU195HF30ZoBU4dMd-AeG6HgRt4Cqy1moGoDgMQfbmQ48Hlunv9_Vi2e2CLvYECcBw"

                            - kid: "d01c1abe249269f72ef7ca2613a86c9f05e59567"
                              kty: "RSA"
                              alg: "RS256"
                              use: "sig"
                              e: "AQAB"
                              n: "0hukqytPwrj1RbMYhYoepCi3CN5k7DwYkTe_Cmb7cP9_qv4ok78KdvFXt5AnQxCRwBD7-qTNkkfMWO2RxUMBdQD0ED6tsSb1n5dp0XY8dSWiBDCX8f6Hr-KolOpvMLZKRy01HdAWcM6RoL9ikbjYHUEW1C8IJnw3MzVHkpKFDL354aptdNLaAdTCBvKzU9WpXo10g-5ctzSlWWjQuecLMQ4G1mNdsR1LHhUENEnOvgT8cDkX0fJzLbEbyBYkdMgKggyVPEB1bg6evG4fTKawgnf0IDSPxIU-wdS9wdSP9ZCJJPLi5CEp-6t6rE_sb2dGcnzjCGlembC57VwpkUvyMw"
                    "#,
                )?;
                // The file that the `client_secret_file: secret` entries point
                // at; it holds the same value as the inline secret above.
                jail.create_file("secret", r"c1!3n753c237")?;

                let config = Figment::new()
                    .merge(Yaml::file("config.yaml"))
                    .extract_inner::<ClientsConfig>("clients")?;

                assert_eq!(config.0.len(), 5);

                assert_eq!(
                    config.0[0].client_id,
                    Ulid::from_str("01GFWR28C4KNE04WG3HKXB7C9R").unwrap()
                );
                assert_eq!(
                    config.0[0].redirect_uris,
                    vec!["https://exemple.fr/callback".parse().unwrap()]
                );

                assert_eq!(
                    config.0[1].client_id,
                    Ulid::from_str("01GFWR32NCQ12B8Z0J8CPXRRB6").unwrap()
                );
                // `redirect_uris` was not given for this client.
                assert_eq!(config.0[1].redirect_uris, Vec::new());

                // The parsed secret keeps track of whether it was given inline
                // or as a file path; the file itself is not read at this point.
                assert!(config.0[0].client_secret.is_none());
                assert!(matches!(config.0[1].client_secret, Some(ClientSecret::File(ref p)) if p == "secret"));
                assert!(matches!(config.0[2].client_secret, Some(ClientSecret::Value(ref v)) if v == "c1!3n753c237"));
                assert!(matches!(config.0[3].client_secret, Some(ClientSecret::File(ref p)) if p == "secret"));
                assert!(config.0[4].client_secret.is_none());

                // Resolving the secret yields the same value whether it was
                // given inline or through the `secret` file.
                Handle::current().block_on(async move {
                    assert_eq!(config.0[1].client_secret().await.unwrap().unwrap(), "c1!3n753c237");
                    assert_eq!(config.0[2].client_secret().await.unwrap().unwrap(), "c1!3n753c237");
                    assert_eq!(config.0[3].client_secret().await.unwrap().unwrap(), "c1!3n753c237");
                });

                Ok(())
            });
        }).await.unwrap();
    }
}