oauth2_types/
response_type.rs

1// Copyright 2024, 2025 New Vector Ltd.
2// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
5// Please see LICENSE files in the repository root for full details.
6
7//! [Response types] in the OpenID Connect specification.
8//!
9//! [Response types]: https://openid.net/specs/openid-connect-core-1_0.html#Authentication
10
11#![allow(clippy::module_name_repetitions)]
12
13use std::{collections::BTreeSet, fmt, iter::FromIterator, str::FromStr};
14
15use mas_iana::oauth::OAuthAuthorizationEndpointResponseType;
16use serde_with::{DeserializeFromStr, SerializeDisplay};
17use thiserror::Error;
18
19/// An error encountered when trying to parse an invalid [`ResponseType`].
20#[derive(Debug, Error, Clone, PartialEq, Eq)]
21#[error("invalid response type")]
22pub struct InvalidResponseType;
23
24/// The accepted tokens in a [`ResponseType`].
25///
26/// `none` is not in this enum because it is represented by an empty
27/// [`ResponseType`].
28///
29/// This type also accepts unknown tokens that can be constructed via it's
30/// `FromStr` implementation or used via its `Display` implementation.
31#[derive(
32    Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, SerializeDisplay, DeserializeFromStr,
33)]
34#[non_exhaustive]
35pub enum ResponseTypeToken {
36    /// `code`
37    Code,
38
39    /// `id_token`
40    IdToken,
41
42    /// `token`
43    Token,
44
45    /// Unknown token.
46    Unknown(String),
47}
48
49impl core::fmt::Display for ResponseTypeToken {
50    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
51        match self {
52            ResponseTypeToken::Code => f.write_str("code"),
53            ResponseTypeToken::IdToken => f.write_str("id_token"),
54            ResponseTypeToken::Token => f.write_str("token"),
55            ResponseTypeToken::Unknown(s) => f.write_str(s),
56        }
57    }
58}
59
60impl core::str::FromStr for ResponseTypeToken {
61    type Err = core::convert::Infallible;
62
63    fn from_str(s: &str) -> Result<Self, Self::Err> {
64        match s {
65            "code" => Ok(Self::Code),
66            "id_token" => Ok(Self::IdToken),
67            "token" => Ok(Self::Token),
68            s => Ok(Self::Unknown(s.to_owned())),
69        }
70    }
71}
72
73/// An [OAuth 2.0 `response_type` value] that the client can use
74/// at the [authorization endpoint].
75///
76/// It is recommended to construct this type from an
77/// [`OAuthAuthorizationEndpointResponseType`].
78///
79/// [OAuth 2.0 `response_type` value]: https://www.rfc-editor.org/rfc/rfc7591#page-9
80/// [authorization endpoint]: https://www.rfc-editor.org/rfc/rfc6749.html#section-3.1
81#[derive(Debug, Clone, PartialEq, Eq, SerializeDisplay, DeserializeFromStr, PartialOrd, Ord)]
82pub struct ResponseType(BTreeSet<ResponseTypeToken>);
83
84impl std::ops::Deref for ResponseType {
85    type Target = BTreeSet<ResponseTypeToken>;
86
87    fn deref(&self) -> &Self::Target {
88        &self.0
89    }
90}
91
92impl ResponseType {
93    /// Whether this response type requests a code.
94    #[must_use]
95    pub fn has_code(&self) -> bool {
96        self.0.contains(&ResponseTypeToken::Code)
97    }
98
99    /// Whether this response type requests an ID token.
100    #[must_use]
101    pub fn has_id_token(&self) -> bool {
102        self.0.contains(&ResponseTypeToken::IdToken)
103    }
104
105    /// Whether this response type requests a token.
106    #[must_use]
107    pub fn has_token(&self) -> bool {
108        self.0.contains(&ResponseTypeToken::Token)
109    }
110}
111
112impl FromStr for ResponseType {
113    type Err = InvalidResponseType;
114
115    fn from_str(s: &str) -> Result<Self, Self::Err> {
116        let s = s.trim();
117
118        if s.is_empty() {
119            Err(InvalidResponseType)
120        } else if s == "none" {
121            Ok(Self(BTreeSet::new()))
122        } else {
123            s.split_ascii_whitespace()
124                .map(|t| ResponseTypeToken::from_str(t).or(Err(InvalidResponseType)))
125                .collect::<Result<_, _>>()
126        }
127    }
128}
129
130impl fmt::Display for ResponseType {
131    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
132        let mut iter = self.iter();
133
134        // First item shouldn't have a leading space
135        if let Some(first) = iter.next() {
136            first.fmt(f)?;
137        } else {
138            // If the whole iterator is empty, write 'none' instead
139            write!(f, "none")?;
140            return Ok(());
141        }
142
143        // Write the other items with a leading space
144        for item in iter {
145            write!(f, " {item}")?;
146        }
147
148        Ok(())
149    }
150}
151
152impl FromIterator<ResponseTypeToken> for ResponseType {
153    fn from_iter<T: IntoIterator<Item = ResponseTypeToken>>(iter: T) -> Self {
154        Self(BTreeSet::from_iter(iter))
155    }
156}
157
158impl From<OAuthAuthorizationEndpointResponseType> for ResponseType {
159    fn from(response_type: OAuthAuthorizationEndpointResponseType) -> Self {
160        match response_type {
161            OAuthAuthorizationEndpointResponseType::Code => Self([ResponseTypeToken::Code].into()),
162            OAuthAuthorizationEndpointResponseType::CodeIdToken => {
163                Self([ResponseTypeToken::Code, ResponseTypeToken::IdToken].into())
164            }
165            OAuthAuthorizationEndpointResponseType::CodeIdTokenToken => Self(
166                [
167                    ResponseTypeToken::Code,
168                    ResponseTypeToken::IdToken,
169                    ResponseTypeToken::Token,
170                ]
171                .into(),
172            ),
173            OAuthAuthorizationEndpointResponseType::CodeToken => {
174                Self([ResponseTypeToken::Code, ResponseTypeToken::Token].into())
175            }
176            OAuthAuthorizationEndpointResponseType::IdToken => {
177                Self([ResponseTypeToken::IdToken].into())
178            }
179            OAuthAuthorizationEndpointResponseType::IdTokenToken => {
180                Self([ResponseTypeToken::IdToken, ResponseTypeToken::Token].into())
181            }
182            OAuthAuthorizationEndpointResponseType::None => Self(BTreeSet::new()),
183            OAuthAuthorizationEndpointResponseType::Token => {
184                Self([ResponseTypeToken::Token].into())
185            }
186        }
187    }
188}
189
190impl TryFrom<ResponseType> for OAuthAuthorizationEndpointResponseType {
191    type Error = InvalidResponseType;
192
193    fn try_from(response_type: ResponseType) -> Result<Self, Self::Error> {
194        if response_type
195            .iter()
196            .any(|t| matches!(t, ResponseTypeToken::Unknown(_)))
197        {
198            return Err(InvalidResponseType);
199        }
200
201        let tokens = response_type.iter().collect::<Vec<_>>();
202        let res = match *tokens {
203            [ResponseTypeToken::Code] => OAuthAuthorizationEndpointResponseType::Code,
204            [ResponseTypeToken::IdToken] => OAuthAuthorizationEndpointResponseType::IdToken,
205            [ResponseTypeToken::Token] => OAuthAuthorizationEndpointResponseType::Token,
206            [ResponseTypeToken::Code, ResponseTypeToken::IdToken] => {
207                OAuthAuthorizationEndpointResponseType::CodeIdToken
208            }
209            [ResponseTypeToken::Code, ResponseTypeToken::Token] => {
210                OAuthAuthorizationEndpointResponseType::CodeToken
211            }
212            [ResponseTypeToken::IdToken, ResponseTypeToken::Token] => {
213                OAuthAuthorizationEndpointResponseType::IdTokenToken
214            }
215            [
216                ResponseTypeToken::Code,
217                ResponseTypeToken::IdToken,
218                ResponseTypeToken::Token,
219            ] => OAuthAuthorizationEndpointResponseType::CodeIdTokenToken,
220            _ => OAuthAuthorizationEndpointResponseType::None,
221        };
222
223        Ok(res)
224    }
225}
226
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn deserialize_response_type_token() {
        assert_eq!(
            serde_json::from_str::<ResponseTypeToken>("\"code\"").unwrap(),
            ResponseTypeToken::Code
        );
        assert_eq!(
            serde_json::from_str::<ResponseTypeToken>("\"id_token\"").unwrap(),
            ResponseTypeToken::IdToken
        );
        assert_eq!(
            serde_json::from_str::<ResponseTypeToken>("\"token\"").unwrap(),
            ResponseTypeToken::Token
        );
        assert_eq!(
            serde_json::from_str::<ResponseTypeToken>("\"something_unsupported\"").unwrap(),
            ResponseTypeToken::Unknown("something_unsupported".to_owned())
        );
    }

    #[test]
    fn serialize_response_type_token() {
        assert_eq!(
            serde_json::to_string(&ResponseTypeToken::Code).unwrap(),
            "\"code\""
        );
        assert_eq!(
            serde_json::to_string(&ResponseTypeToken::IdToken).unwrap(),
            "\"id_token\""
        );
        assert_eq!(
            serde_json::to_string(&ResponseTypeToken::Token).unwrap(),
            "\"token\""
        );
        assert_eq!(
            serde_json::to_string(&ResponseTypeToken::Unknown(
                "something_unsupported".to_owned()
            ))
            .unwrap(),
            "\"something_unsupported\""
        );
    }

    #[test]
    fn deserialize_response_type() {
        // Empty and whitespace-only values are rejected.
        serde_json::from_str::<ResponseType>("\"\"").unwrap_err();
        serde_json::from_str::<ResponseType>("\"   \"").unwrap_err();

        let res_type = serde_json::from_str::<ResponseType>("\"none\"").unwrap();
        let mut iter = res_type.iter();
        assert_eq!(iter.next(), None);
        assert_eq!(
            OAuthAuthorizationEndpointResponseType::try_from(res_type).unwrap(),
            OAuthAuthorizationEndpointResponseType::None
        );

        let res_type = serde_json::from_str::<ResponseType>("\"code\"").unwrap();
        let mut iter = res_type.iter();
        assert_eq!(iter.next(), Some(&ResponseTypeToken::Code));
        assert_eq!(iter.next(), None);
        assert_eq!(
            OAuthAuthorizationEndpointResponseType::try_from(res_type).unwrap(),
            OAuthAuthorizationEndpointResponseType::Code
        );

        // Surrounding and repeated whitespace is tolerated.
        let res_type = serde_json::from_str::<ResponseType>("\"  code   id_token \"").unwrap();
        let mut iter = res_type.iter();
        assert_eq!(iter.next(), Some(&ResponseTypeToken::Code));
        assert_eq!(iter.next(), Some(&ResponseTypeToken::IdToken));
        assert_eq!(iter.next(), None);
        assert_eq!(
            OAuthAuthorizationEndpointResponseType::try_from(res_type).unwrap(),
            OAuthAuthorizationEndpointResponseType::CodeIdToken
        );

        let res_type = serde_json::from_str::<ResponseType>("\"id_token\"").unwrap();
        let mut iter = res_type.iter();
        assert_eq!(iter.next(), Some(&ResponseTypeToken::IdToken));
        assert_eq!(iter.next(), None);
        assert_eq!(
            OAuthAuthorizationEndpointResponseType::try_from(res_type).unwrap(),
            OAuthAuthorizationEndpointResponseType::IdToken
        );

        let res_type = serde_json::from_str::<ResponseType>("\"token\"").unwrap();
        let mut iter = res_type.iter();
        assert_eq!(iter.next(), Some(&ResponseTypeToken::Token));
        assert_eq!(iter.next(), None);
        assert_eq!(
            OAuthAuthorizationEndpointResponseType::try_from(res_type).unwrap(),
            OAuthAuthorizationEndpointResponseType::Token
        );

        let res_type = serde_json::from_str::<ResponseType>("\"something_unsupported\"").unwrap();
        let mut iter = res_type.iter();
        assert_eq!(
            iter.next(),
            Some(&ResponseTypeToken::Unknown(
                "something_unsupported".to_owned()
            ))
        );
        assert_eq!(iter.next(), None);
        OAuthAuthorizationEndpointResponseType::try_from(res_type).unwrap_err();

        let res_type = serde_json::from_str::<ResponseType>("\"code id_token\"").unwrap();
        let mut iter = res_type.iter();
        assert_eq!(iter.next(), Some(&ResponseTypeToken::Code));
        assert_eq!(iter.next(), Some(&ResponseTypeToken::IdToken));
        assert_eq!(iter.next(), None);
        assert_eq!(
            OAuthAuthorizationEndpointResponseType::try_from(res_type).unwrap(),
            OAuthAuthorizationEndpointResponseType::CodeIdToken
        );

        let res_type = serde_json::from_str::<ResponseType>("\"code token\"").unwrap();
        let mut iter = res_type.iter();
        assert_eq!(iter.next(), Some(&ResponseTypeToken::Code));
        assert_eq!(iter.next(), Some(&ResponseTypeToken::Token));
        assert_eq!(iter.next(), None);
        assert_eq!(
            OAuthAuthorizationEndpointResponseType::try_from(res_type).unwrap(),
            OAuthAuthorizationEndpointResponseType::CodeToken
        );

        let res_type = serde_json::from_str::<ResponseType>("\"id_token token\"").unwrap();
        let mut iter = res_type.iter();
        assert_eq!(iter.next(), Some(&ResponseTypeToken::IdToken));
        assert_eq!(iter.next(), Some(&ResponseTypeToken::Token));
        assert_eq!(iter.next(), None);
        assert_eq!(
            OAuthAuthorizationEndpointResponseType::try_from(res_type).unwrap(),
            OAuthAuthorizationEndpointResponseType::IdTokenToken
        );

        let res_type = serde_json::from_str::<ResponseType>("\"code id_token token\"").unwrap();
        let mut iter = res_type.iter();
        assert_eq!(iter.next(), Some(&ResponseTypeToken::Code));
        assert_eq!(iter.next(), Some(&ResponseTypeToken::IdToken));
        assert_eq!(iter.next(), Some(&ResponseTypeToken::Token));
        assert_eq!(iter.next(), None);
        assert_eq!(
            OAuthAuthorizationEndpointResponseType::try_from(res_type).unwrap(),
            OAuthAuthorizationEndpointResponseType::CodeIdTokenToken
        );

        let res_type =
            serde_json::from_str::<ResponseType>("\"code id_token token something_unsupported\"")
                .unwrap();
        let mut iter = res_type.iter();
        assert_eq!(iter.next(), Some(&ResponseTypeToken::Code));
        assert_eq!(iter.next(), Some(&ResponseTypeToken::IdToken));
        assert_eq!(iter.next(), Some(&ResponseTypeToken::Token));
        assert_eq!(
            iter.next(),
            Some(&ResponseTypeToken::Unknown(
                "something_unsupported".to_owned()
            ))
        );
        assert_eq!(iter.next(), None);
        OAuthAuthorizationEndpointResponseType::try_from(res_type).unwrap_err();

        // Order doesn't matter
        let res_type = serde_json::from_str::<ResponseType>("\"token code id_token\"").unwrap();
        let mut iter = res_type.iter();
        assert_eq!(iter.next(), Some(&ResponseTypeToken::Code));
        assert_eq!(iter.next(), Some(&ResponseTypeToken::IdToken));
        assert_eq!(iter.next(), Some(&ResponseTypeToken::Token));
        assert_eq!(iter.next(), None);
        assert_eq!(
            OAuthAuthorizationEndpointResponseType::try_from(res_type).unwrap(),
            OAuthAuthorizationEndpointResponseType::CodeIdTokenToken
        );

        // Duplicated tokens collapse into one.
        let res_type =
            serde_json::from_str::<ResponseType>("\"id_token token id_token code\"").unwrap();
        let mut iter = res_type.iter();
        assert_eq!(iter.next(), Some(&ResponseTypeToken::Code));
        assert_eq!(iter.next(), Some(&ResponseTypeToken::IdToken));
        assert_eq!(iter.next(), Some(&ResponseTypeToken::Token));
        assert_eq!(iter.next(), None);
        assert_eq!(
            OAuthAuthorizationEndpointResponseType::try_from(res_type).unwrap(),
            OAuthAuthorizationEndpointResponseType::CodeIdTokenToken
        );
    }

    #[test]
    fn serialize_response_type() {
        assert_eq!(
            serde_json::to_string(&ResponseType::from(
                OAuthAuthorizationEndpointResponseType::None
            ))
            .unwrap(),
            "\"none\""
        );
        assert_eq!(
            serde_json::to_string(&ResponseType::from(
                OAuthAuthorizationEndpointResponseType::Code
            ))
            .unwrap(),
            "\"code\""
        );
        assert_eq!(
            serde_json::to_string(&ResponseType::from(
                OAuthAuthorizationEndpointResponseType::IdToken
            ))
            .unwrap(),
            "\"id_token\""
        );
        assert_eq!(
            serde_json::to_string(&ResponseType::from(
                OAuthAuthorizationEndpointResponseType::CodeIdToken
            ))
            .unwrap(),
            "\"code id_token\""
        );
        assert_eq!(
            serde_json::to_string(&ResponseType::from(
                OAuthAuthorizationEndpointResponseType::CodeToken
            ))
            .unwrap(),
            "\"code token\""
        );
        assert_eq!(
            serde_json::to_string(&ResponseType::from(
                OAuthAuthorizationEndpointResponseType::IdTokenToken
            ))
            .unwrap(),
            "\"id_token token\""
        );
        assert_eq!(
            serde_json::to_string(&ResponseType::from(
                OAuthAuthorizationEndpointResponseType::CodeIdTokenToken
            ))
            .unwrap(),
            "\"code id_token token\""
        );

        assert_eq!(
            serde_json::to_string(
                &[
                    ResponseTypeToken::Unknown("something_unsupported".to_owned()),
                    ResponseTypeToken::Code
                ]
                .into_iter()
                .collect::<ResponseType>()
            )
            .unwrap(),
            "\"code something_unsupported\""
        );

        // Order doesn't matter.
        let res = [
            ResponseTypeToken::IdToken,
            ResponseTypeToken::Token,
            ResponseTypeToken::Code,
        ]
        .into_iter()
        .collect::<ResponseType>();
        assert_eq!(
            serde_json::to_string(&res).unwrap(),
            "\"code id_token token\""
        );

        let res = [
            ResponseTypeToken::Code,
            ResponseTypeToken::Token,
            ResponseTypeToken::IdToken,
        ]
        .into_iter()
        .collect::<ResponseType>();
        assert_eq!(
            serde_json::to_string(&res).unwrap(),
            "\"code id_token token\""
        );
    }
}