1#![allow(clippy::module_name_repetitions)]
12
13use std::{collections::BTreeSet, fmt, iter::FromIterator, str::FromStr};
14
15use mas_iana::oauth::OAuthAuthorizationEndpointResponseType;
16use serde_with::{DeserializeFromStr, SerializeDisplay};
17use thiserror::Error;
18
19#[derive(Debug, Error, Clone, PartialEq, Eq)]
21#[error("invalid response type")]
22pub struct InvalidResponseType;
23
24#[derive(
32 Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, SerializeDisplay, DeserializeFromStr,
33)]
34#[non_exhaustive]
35pub enum ResponseTypeToken {
36 Code,
38
39 IdToken,
41
42 Token,
44
45 Unknown(String),
47}
48
49impl core::fmt::Display for ResponseTypeToken {
50 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
51 match self {
52 ResponseTypeToken::Code => f.write_str("code"),
53 ResponseTypeToken::IdToken => f.write_str("id_token"),
54 ResponseTypeToken::Token => f.write_str("token"),
55 ResponseTypeToken::Unknown(s) => f.write_str(s),
56 }
57 }
58}
59
60impl core::str::FromStr for ResponseTypeToken {
61 type Err = core::convert::Infallible;
62
63 fn from_str(s: &str) -> Result<Self, Self::Err> {
64 match s {
65 "code" => Ok(Self::Code),
66 "id_token" => Ok(Self::IdToken),
67 "token" => Ok(Self::Token),
68 s => Ok(Self::Unknown(s.to_owned())),
69 }
70 }
71}
72
73#[derive(Debug, Clone, PartialEq, Eq, SerializeDisplay, DeserializeFromStr, PartialOrd, Ord)]
82pub struct ResponseType(BTreeSet<ResponseTypeToken>);
83
84impl std::ops::Deref for ResponseType {
85 type Target = BTreeSet<ResponseTypeToken>;
86
87 fn deref(&self) -> &Self::Target {
88 &self.0
89 }
90}
91
92impl ResponseType {
93 #[must_use]
95 pub fn has_code(&self) -> bool {
96 self.0.contains(&ResponseTypeToken::Code)
97 }
98
99 #[must_use]
101 pub fn has_id_token(&self) -> bool {
102 self.0.contains(&ResponseTypeToken::IdToken)
103 }
104
105 #[must_use]
107 pub fn has_token(&self) -> bool {
108 self.0.contains(&ResponseTypeToken::Token)
109 }
110}
111
112impl FromStr for ResponseType {
113 type Err = InvalidResponseType;
114
115 fn from_str(s: &str) -> Result<Self, Self::Err> {
116 let s = s.trim();
117
118 if s.is_empty() {
119 Err(InvalidResponseType)
120 } else if s == "none" {
121 Ok(Self(BTreeSet::new()))
122 } else {
123 s.split_ascii_whitespace()
124 .map(|t| ResponseTypeToken::from_str(t).or(Err(InvalidResponseType)))
125 .collect::<Result<_, _>>()
126 }
127 }
128}
129
130impl fmt::Display for ResponseType {
131 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
132 let mut iter = self.iter();
133
134 if let Some(first) = iter.next() {
136 first.fmt(f)?;
137 } else {
138 write!(f, "none")?;
140 return Ok(());
141 }
142
143 for item in iter {
145 write!(f, " {item}")?;
146 }
147
148 Ok(())
149 }
150}
151
152impl FromIterator<ResponseTypeToken> for ResponseType {
153 fn from_iter<T: IntoIterator<Item = ResponseTypeToken>>(iter: T) -> Self {
154 Self(BTreeSet::from_iter(iter))
155 }
156}
157
158impl From<OAuthAuthorizationEndpointResponseType> for ResponseType {
159 fn from(response_type: OAuthAuthorizationEndpointResponseType) -> Self {
160 match response_type {
161 OAuthAuthorizationEndpointResponseType::Code => Self([ResponseTypeToken::Code].into()),
162 OAuthAuthorizationEndpointResponseType::CodeIdToken => {
163 Self([ResponseTypeToken::Code, ResponseTypeToken::IdToken].into())
164 }
165 OAuthAuthorizationEndpointResponseType::CodeIdTokenToken => Self(
166 [
167 ResponseTypeToken::Code,
168 ResponseTypeToken::IdToken,
169 ResponseTypeToken::Token,
170 ]
171 .into(),
172 ),
173 OAuthAuthorizationEndpointResponseType::CodeToken => {
174 Self([ResponseTypeToken::Code, ResponseTypeToken::Token].into())
175 }
176 OAuthAuthorizationEndpointResponseType::IdToken => {
177 Self([ResponseTypeToken::IdToken].into())
178 }
179 OAuthAuthorizationEndpointResponseType::IdTokenToken => {
180 Self([ResponseTypeToken::IdToken, ResponseTypeToken::Token].into())
181 }
182 OAuthAuthorizationEndpointResponseType::None => Self(BTreeSet::new()),
183 OAuthAuthorizationEndpointResponseType::Token => {
184 Self([ResponseTypeToken::Token].into())
185 }
186 }
187 }
188}
189
190impl TryFrom<ResponseType> for OAuthAuthorizationEndpointResponseType {
191 type Error = InvalidResponseType;
192
193 fn try_from(response_type: ResponseType) -> Result<Self, Self::Error> {
194 if response_type
195 .iter()
196 .any(|t| matches!(t, ResponseTypeToken::Unknown(_)))
197 {
198 return Err(InvalidResponseType);
199 }
200
201 let tokens = response_type.iter().collect::<Vec<_>>();
202 let res = match *tokens {
203 [ResponseTypeToken::Code] => OAuthAuthorizationEndpointResponseType::Code,
204 [ResponseTypeToken::IdToken] => OAuthAuthorizationEndpointResponseType::IdToken,
205 [ResponseTypeToken::Token] => OAuthAuthorizationEndpointResponseType::Token,
206 [ResponseTypeToken::Code, ResponseTypeToken::IdToken] => {
207 OAuthAuthorizationEndpointResponseType::CodeIdToken
208 }
209 [ResponseTypeToken::Code, ResponseTypeToken::Token] => {
210 OAuthAuthorizationEndpointResponseType::CodeToken
211 }
212 [ResponseTypeToken::IdToken, ResponseTypeToken::Token] => {
213 OAuthAuthorizationEndpointResponseType::IdTokenToken
214 }
215 [
216 ResponseTypeToken::Code,
217 ResponseTypeToken::IdToken,
218 ResponseTypeToken::Token,
219 ] => OAuthAuthorizationEndpointResponseType::CodeIdTokenToken,
220 _ => OAuthAuthorizationEndpointResponseType::None,
221 };
222
223 Ok(res)
224 }
225}
226
227#[cfg(test)]
228mod tests {
229 use super::*;
230
231 #[test]
232 fn deserialize_response_type_token() {
233 assert_eq!(
234 serde_json::from_str::<ResponseTypeToken>("\"code\"").unwrap(),
235 ResponseTypeToken::Code
236 );
237 assert_eq!(
238 serde_json::from_str::<ResponseTypeToken>("\"id_token\"").unwrap(),
239 ResponseTypeToken::IdToken
240 );
241 assert_eq!(
242 serde_json::from_str::<ResponseTypeToken>("\"token\"").unwrap(),
243 ResponseTypeToken::Token
244 );
245 assert_eq!(
246 serde_json::from_str::<ResponseTypeToken>("\"something_unsupported\"").unwrap(),
247 ResponseTypeToken::Unknown("something_unsupported".to_owned())
248 );
249 }
250
251 #[test]
252 fn serialize_response_type_token() {
253 assert_eq!(
254 serde_json::to_string(&ResponseTypeToken::Code).unwrap(),
255 "\"code\""
256 );
257 assert_eq!(
258 serde_json::to_string(&ResponseTypeToken::IdToken).unwrap(),
259 "\"id_token\""
260 );
261 assert_eq!(
262 serde_json::to_string(&ResponseTypeToken::Token).unwrap(),
263 "\"token\""
264 );
265 assert_eq!(
266 serde_json::to_string(&ResponseTypeToken::Unknown(
267 "something_unsupported".to_owned()
268 ))
269 .unwrap(),
270 "\"something_unsupported\""
271 );
272 }
273
274 #[test]
275 fn deserialize_response_type() {
276 serde_json::from_str::<ResponseType>("\"\"").unwrap_err();
277
278 let res_type = serde_json::from_str::<ResponseType>("\"none\"").unwrap();
279 let mut iter = res_type.iter();
280 assert_eq!(iter.next(), None);
281 assert_eq!(
282 OAuthAuthorizationEndpointResponseType::try_from(res_type).unwrap(),
283 OAuthAuthorizationEndpointResponseType::None
284 );
285
286 let res_type = serde_json::from_str::<ResponseType>("\"code\"").unwrap();
287 let mut iter = res_type.iter();
288 assert_eq!(iter.next(), Some(&ResponseTypeToken::Code));
289 assert_eq!(iter.next(), None);
290 assert_eq!(
291 OAuthAuthorizationEndpointResponseType::try_from(res_type).unwrap(),
292 OAuthAuthorizationEndpointResponseType::Code
293 );
294
295 let res_type = serde_json::from_str::<ResponseType>("\"code\"").unwrap();
296 let mut iter = res_type.iter();
297 assert_eq!(iter.next(), Some(&ResponseTypeToken::Code));
298 assert_eq!(iter.next(), None);
299 assert_eq!(
300 OAuthAuthorizationEndpointResponseType::try_from(res_type).unwrap(),
301 OAuthAuthorizationEndpointResponseType::Code
302 );
303
304 let res_type = serde_json::from_str::<ResponseType>("\"id_token\"").unwrap();
305 let mut iter = res_type.iter();
306 assert_eq!(iter.next(), Some(&ResponseTypeToken::IdToken));
307 assert_eq!(iter.next(), None);
308 assert_eq!(
309 OAuthAuthorizationEndpointResponseType::try_from(res_type).unwrap(),
310 OAuthAuthorizationEndpointResponseType::IdToken
311 );
312
313 let res_type = serde_json::from_str::<ResponseType>("\"token\"").unwrap();
314 let mut iter = res_type.iter();
315 assert_eq!(iter.next(), Some(&ResponseTypeToken::Token));
316 assert_eq!(iter.next(), None);
317 assert_eq!(
318 OAuthAuthorizationEndpointResponseType::try_from(res_type).unwrap(),
319 OAuthAuthorizationEndpointResponseType::Token
320 );
321
322 let res_type = serde_json::from_str::<ResponseType>("\"something_unsupported\"").unwrap();
323 let mut iter = res_type.iter();
324 assert_eq!(
325 iter.next(),
326 Some(&ResponseTypeToken::Unknown(
327 "something_unsupported".to_owned()
328 ))
329 );
330 assert_eq!(iter.next(), None);
331 OAuthAuthorizationEndpointResponseType::try_from(res_type).unwrap_err();
332
333 let res_type = serde_json::from_str::<ResponseType>("\"code id_token\"").unwrap();
334 let mut iter = res_type.iter();
335 assert_eq!(iter.next(), Some(&ResponseTypeToken::Code));
336 assert_eq!(iter.next(), Some(&ResponseTypeToken::IdToken));
337 assert_eq!(iter.next(), None);
338 assert_eq!(
339 OAuthAuthorizationEndpointResponseType::try_from(res_type).unwrap(),
340 OAuthAuthorizationEndpointResponseType::CodeIdToken
341 );
342
343 let res_type = serde_json::from_str::<ResponseType>("\"code token\"").unwrap();
344 let mut iter = res_type.iter();
345 assert_eq!(iter.next(), Some(&ResponseTypeToken::Code));
346 assert_eq!(iter.next(), Some(&ResponseTypeToken::Token));
347 assert_eq!(iter.next(), None);
348 assert_eq!(
349 OAuthAuthorizationEndpointResponseType::try_from(res_type).unwrap(),
350 OAuthAuthorizationEndpointResponseType::CodeToken
351 );
352
353 let res_type = serde_json::from_str::<ResponseType>("\"id_token token\"").unwrap();
354 let mut iter = res_type.iter();
355 assert_eq!(iter.next(), Some(&ResponseTypeToken::IdToken));
356 assert_eq!(iter.next(), Some(&ResponseTypeToken::Token));
357 assert_eq!(iter.next(), None);
358 assert_eq!(
359 OAuthAuthorizationEndpointResponseType::try_from(res_type).unwrap(),
360 OAuthAuthorizationEndpointResponseType::IdTokenToken
361 );
362
363 let res_type = serde_json::from_str::<ResponseType>("\"code id_token token\"").unwrap();
364 let mut iter = res_type.iter();
365 assert_eq!(iter.next(), Some(&ResponseTypeToken::Code));
366 assert_eq!(iter.next(), Some(&ResponseTypeToken::IdToken));
367 assert_eq!(iter.next(), Some(&ResponseTypeToken::Token));
368 assert_eq!(iter.next(), None);
369 assert_eq!(
370 OAuthAuthorizationEndpointResponseType::try_from(res_type).unwrap(),
371 OAuthAuthorizationEndpointResponseType::CodeIdTokenToken
372 );
373
374 let res_type =
375 serde_json::from_str::<ResponseType>("\"code id_token token something_unsupported\"")
376 .unwrap();
377 let mut iter = res_type.iter();
378 assert_eq!(iter.next(), Some(&ResponseTypeToken::Code));
379 assert_eq!(iter.next(), Some(&ResponseTypeToken::IdToken));
380 assert_eq!(iter.next(), Some(&ResponseTypeToken::Token));
381 assert_eq!(
382 iter.next(),
383 Some(&ResponseTypeToken::Unknown(
384 "something_unsupported".to_owned()
385 ))
386 );
387 assert_eq!(iter.next(), None);
388 OAuthAuthorizationEndpointResponseType::try_from(res_type).unwrap_err();
389
390 let res_type = serde_json::from_str::<ResponseType>("\"token code id_token\"").unwrap();
392 let mut iter = res_type.iter();
393 assert_eq!(iter.next(), Some(&ResponseTypeToken::Code));
394 assert_eq!(iter.next(), Some(&ResponseTypeToken::IdToken));
395 assert_eq!(iter.next(), Some(&ResponseTypeToken::Token));
396 assert_eq!(iter.next(), None);
397 assert_eq!(
398 OAuthAuthorizationEndpointResponseType::try_from(res_type).unwrap(),
399 OAuthAuthorizationEndpointResponseType::CodeIdTokenToken
400 );
401
402 let res_type =
403 serde_json::from_str::<ResponseType>("\"id_token token id_token code\"").unwrap();
404 let mut iter = res_type.iter();
405 assert_eq!(iter.next(), Some(&ResponseTypeToken::Code));
406 assert_eq!(iter.next(), Some(&ResponseTypeToken::IdToken));
407 assert_eq!(iter.next(), Some(&ResponseTypeToken::Token));
408 assert_eq!(iter.next(), None);
409 assert_eq!(
410 OAuthAuthorizationEndpointResponseType::try_from(res_type).unwrap(),
411 OAuthAuthorizationEndpointResponseType::CodeIdTokenToken
412 );
413 }
414
415 #[test]
416 fn serialize_response_type() {
417 assert_eq!(
418 serde_json::to_string(&ResponseType::from(
419 OAuthAuthorizationEndpointResponseType::None
420 ))
421 .unwrap(),
422 "\"none\""
423 );
424 assert_eq!(
425 serde_json::to_string(&ResponseType::from(
426 OAuthAuthorizationEndpointResponseType::Code
427 ))
428 .unwrap(),
429 "\"code\""
430 );
431 assert_eq!(
432 serde_json::to_string(&ResponseType::from(
433 OAuthAuthorizationEndpointResponseType::IdToken
434 ))
435 .unwrap(),
436 "\"id_token\""
437 );
438 assert_eq!(
439 serde_json::to_string(&ResponseType::from(
440 OAuthAuthorizationEndpointResponseType::CodeIdToken
441 ))
442 .unwrap(),
443 "\"code id_token\""
444 );
445 assert_eq!(
446 serde_json::to_string(&ResponseType::from(
447 OAuthAuthorizationEndpointResponseType::CodeToken
448 ))
449 .unwrap(),
450 "\"code token\""
451 );
452 assert_eq!(
453 serde_json::to_string(&ResponseType::from(
454 OAuthAuthorizationEndpointResponseType::IdTokenToken
455 ))
456 .unwrap(),
457 "\"id_token token\""
458 );
459 assert_eq!(
460 serde_json::to_string(&ResponseType::from(
461 OAuthAuthorizationEndpointResponseType::CodeIdTokenToken
462 ))
463 .unwrap(),
464 "\"code id_token token\""
465 );
466
467 assert_eq!(
468 serde_json::to_string(
469 &[
470 ResponseTypeToken::Unknown("something_unsupported".to_owned()),
471 ResponseTypeToken::Code
472 ]
473 .into_iter()
474 .collect::<ResponseType>()
475 )
476 .unwrap(),
477 "\"code something_unsupported\""
478 );
479
480 let res = [
482 ResponseTypeToken::IdToken,
483 ResponseTypeToken::Token,
484 ResponseTypeToken::Code,
485 ]
486 .into_iter()
487 .collect::<ResponseType>();
488 assert_eq!(
489 serde_json::to_string(&res).unwrap(),
490 "\"code id_token token\""
491 );
492
493 let res = [
494 ResponseTypeToken::Code,
495 ResponseTypeToken::Token,
496 ResponseTypeToken::IdToken,
497 ]
498 .into_iter()
499 .collect::<ResponseType>();
500 assert_eq!(
501 serde_json::to_string(&res).unwrap(),
502 "\"code id_token token\""
503 );
504 }
505}