use chrono::{DateTime, Utc};
use mas_iana::{
jose::JsonWebSignatureAlg,
oauth::{OAuthAuthorizationEndpointResponseType, OAuthClientAuthenticationMethod},
};
use mas_jose::jwk::PublicJsonWebKeySet;
use oauth2_types::{oidc::ApplicationType, requests::GrantType};
use rand::RngCore;
use serde::Serialize;
use thiserror::Error;
use ulid::Ulid;
use url::Url;
/// A client's JSON Web Key Set, held either inline or as a URL.
///
/// Variant names are serialized in snake_case (`jwks`, `jwks_uri`) because of
/// the `rename_all` attribute below.
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum JwksOrJwksUri {
    /// The key set itself, embedded directly.
    Jwks(PublicJsonWebKeySet),

    /// A URL for the key set.
    // NOTE(review): presumably the key set is fetched from this URL by the
    // consumer; nothing in this file does so — confirm against callers.
    JwksUri(Url),
}
/// An OAuth 2.0 client together with its registered metadata.
///
/// Fields are serialized in declaration order, so keep that order stable.
///
/// NOTE(review): the field names appear to mirror the client metadata of
/// OpenID Connect Dynamic Client Registration; the per-field descriptions
/// below are derived from those names and should be confirmed against the code
/// that populates this struct.
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct Client {
    /// ULID identifying this client.
    pub id: Ulid,

    /// The `client_id` string of this client.
    pub client_id: String,

    /// The client secret in encrypted form, if the client has one.
    pub encrypted_client_secret: Option<String>,

    /// Application type of the client, if specified.
    pub application_type: Option<ApplicationType>,

    /// Redirect URIs registered for this client.
    ///
    /// [`Client::resolve_redirect_uri`] validates requested redirect URIs
    /// against this list.
    pub redirect_uris: Vec<Url>,

    /// Authorization endpoint response types associated with this client.
    pub response_types: Vec<OAuthAuthorizationEndpointResponseType>,

    /// Grant types associated with this client.
    pub grant_types: Vec<GrantType>,

    /// Contact strings for this client.
    pub contacts: Vec<String>,

    /// Name of the client, if any.
    pub client_name: Option<String>,

    /// URL of the client's logo, if any.
    pub logo_uri: Option<Url>,

    /// URL of the client itself, if any.
    pub client_uri: Option<Url>,

    /// URL of the client's policy document, if any.
    pub policy_uri: Option<Url>,

    /// URL of the client's terms of service, if any.
    pub tos_uri: Option<Url>,

    /// The client's key set, inline or by URL, if any.
    pub jwks: Option<JwksOrJwksUri>,

    /// JWS algorithm configured for signed ID token responses, if any.
    pub id_token_signed_response_alg: Option<JsonWebSignatureAlg>,

    /// JWS algorithm configured for signed userinfo responses, if any.
    pub userinfo_signed_response_alg: Option<JsonWebSignatureAlg>,

    /// Client authentication method for the token endpoint, if any.
    pub token_endpoint_auth_method: Option<OAuthClientAuthenticationMethod>,

    /// JWS algorithm associated with token endpoint authentication, if any.
    pub token_endpoint_auth_signing_alg: Option<JsonWebSignatureAlg>,

    /// URL for initiating a login at the client, if any.
    pub initiate_login_uri: Option<Url>,
}
/// Reasons why [`Client::resolve_redirect_uri`] can refuse to resolve a
/// redirect URI.
#[derive(Debug, Error)]
pub enum InvalidRedirectUriError {
    /// A redirect URI was supplied, but it matches none of the client's
    /// registered redirect URIs.
    #[error("redirect_uri is not allowed for this client")]
    NotAllowed,

    /// No redirect URI was supplied and the client has more than one
    /// registered, so there is no single URI to fall back to.
    #[error("multiple redirect_uris registered for this client")]
    MultipleRegistered,

    /// The client has no registered redirect URI at all; this is reported
    /// whether or not a redirect URI was supplied.
    #[error("client has no redirect_uri registered")]
    NoneRegistered,
}
impl Client {
    /// Determines the redirect URI to use for a request made by this client.
    ///
    /// * When no `redirect_uri` is supplied, the client's single registered
    ///   redirect URI is returned.
    /// * When a `redirect_uri` is supplied, that same URI is returned,
    ///   provided it matches one of the registered URIs according to
    ///   `uri_matches_one_of`.
    ///
    /// # Errors
    ///
    /// * [`InvalidRedirectUriError::NoneRegistered`] if the client has no
    ///   registered redirect URI, whether or not one was supplied.
    /// * [`InvalidRedirectUriError::MultipleRegistered`] if none was supplied
    ///   and the client has more than one registered.
    /// * [`InvalidRedirectUriError::NotAllowed`] if the supplied URI matches
    ///   none of the registered ones.
    pub fn resolve_redirect_uri<'a>(
        &'a self,
        redirect_uri: &'a Option<Url>,
    ) -> Result<&'a Url, InvalidRedirectUriError> {
        let registered = self.redirect_uris.as_slice();

        // An empty registration always fails, regardless of what was requested.
        if registered.is_empty() {
            return Err(InvalidRedirectUriError::NoneRegistered);
        }

        match redirect_uri {
            Some(requested) => {
                if uri_matches_one_of(requested, registered) {
                    // Hand back the caller's URI, not the registered one: for
                    // local hosts the two may differ by port.
                    Ok(requested)
                } else {
                    Err(InvalidRedirectUriError::NotAllowed)
                }
            }
            // Nothing requested: only unambiguous when exactly one is registered.
            None => match registered {
                [only] => Ok(only),
                _ => Err(InvalidRedirectUriError::MultipleRegistered),
            },
        }
    }

    /// Builds two sample clients, `client1` and `client2`.
    ///
    /// Their IDs are generated from `now` and `rng`, with the ID of `client1`
    /// drawn from `rng` before the ID of `client2`.
    ///
    /// # Panics
    ///
    /// Panics if one of the hard-coded sample URLs fails to parse.
    #[doc(hidden)]
    pub fn samples(now: DateTime<Utc>, rng: &mut impl RngCore) -> Vec<Client> {
        // All sample URLs are literals, so a parse failure is a programming error.
        let parse_url = |raw: &str| Url::parse(raw).unwrap();

        // A web client with every optional URL filled in and two redirect URIs.
        let first = Self {
            id: Ulid::from_datetime_with_source(now.into(), rng),
            client_id: "client1".to_owned(),
            encrypted_client_secret: None,
            application_type: Some(ApplicationType::Web),
            redirect_uris: vec![
                parse_url("https://client1.example.com/redirect"),
                parse_url("https://client1.example.com/redirect2"),
            ],
            response_types: vec![OAuthAuthorizationEndpointResponseType::Code],
            grant_types: vec![GrantType::AuthorizationCode, GrantType::RefreshToken],
            contacts: vec!["foo@client1.example.com".to_owned()],
            client_name: Some("Client 1".to_owned()),
            client_uri: Some(parse_url("https://client1.example.com")),
            logo_uri: Some(parse_url("https://client1.example.com/logo.png")),
            tos_uri: Some(parse_url("https://client1.example.com/tos")),
            policy_uri: Some(parse_url("https://client1.example.com/policy")),
            initiate_login_uri: Some(parse_url("https://client1.example.com/initiate-login")),
            token_endpoint_auth_method: Some(OAuthClientAuthenticationMethod::None),
            token_endpoint_auth_signing_alg: None,
            id_token_signed_response_alg: None,
            userinfo_signed_response_alg: None,
            jwks: None,
        };

        // A native client with a single redirect URI and no optional metadata.
        let second = Self {
            id: Ulid::from_datetime_with_source(now.into(), rng),
            client_id: "client2".to_owned(),
            encrypted_client_secret: None,
            application_type: Some(ApplicationType::Native),
            redirect_uris: vec![parse_url("https://client2.example.com/redirect")],
            response_types: vec![OAuthAuthorizationEndpointResponseType::Code],
            grant_types: vec![GrantType::AuthorizationCode, GrantType::RefreshToken],
            contacts: vec!["foo@client2.example.com".to_owned()],
            client_name: None,
            client_uri: None,
            logo_uri: None,
            tos_uri: None,
            policy_uri: None,
            initiate_login_uri: None,
            token_endpoint_auth_method: None,
            token_endpoint_auth_signing_alg: None,
            id_token_signed_response_alg: None,
            userinfo_signed_response_alg: None,
            jwks: None,
        };

        vec![first, second]
    }
}
/// Host strings treated as local by `uri_matches_one_of`, which compares them
/// against `Url::host_str()`; for URIs with one of these hosts the port is
/// ignored when matching against a registered URI.
const LOCAL_HOSTS: &[&str] = &["localhost", "127.0.0.1", "[::1]"];
/// Checks whether `uri` is acceptable given a set of registered URIs.
///
/// If the host of `uri` is listed in `LOCAL_HOSTS`, a copy of `uri` with its
/// port removed is compared against `registered_uris` first, so a local URI
/// registered without a port accepts any port. In every other case, and when
/// that relaxed comparison fails, `uri` must equal one of the registered URIs.
fn uri_matches_one_of(uri: &Url, registered_uris: &[Url]) -> bool {
    // A URI without a host is compared as the empty string, which is not local.
    let host = uri.host_str().unwrap_or_default();
    let is_local = LOCAL_HOSTS.iter().any(|candidate| *candidate == host);

    if is_local {
        let mut portless = uri.clone();
        // Only the requested URI's port is dropped; registered URIs are
        // compared as they are.
        let relaxed_match =
            portless.set_port(None).is_ok() && registered_uris.contains(&portless);
        if relaxed_match {
            return true;
        }
    }

    registered_uris.iter().any(|registered| registered == uri)
}
#[cfg(test)]
mod tests {
    use url::Url;

    use super::*;

    /// Exercises `uri_matches_one_of` against one local and one remote
    /// registered URI.
    #[test]
    fn test_uri_matches_one_of() {
        let registered_uris = [
            Url::parse("http://127.0.0.1").unwrap(),
            Url::parse("https://example.org").unwrap(),
        ];

        // (candidate URI, whether it is expected to match)
        let cases = [
            // Exact match on a non-local host.
            ("https://example.org", true),
            // A non-local host must match the port exactly.
            ("https://example.org:8080", false),
            // Exact match on a local host.
            ("http://127.0.0.1", true),
            // A local host matches regardless of port.
            ("http://127.0.0.1:8080", true),
            // A local host that was never registered still does not match.
            ("http://localhost", false),
        ];

        for (candidate, expected) in cases.iter() {
            let uri = Url::parse(candidate).unwrap();
            assert_eq!(
                uri_matches_one_of(&uri, &registered_uris),
                *expected,
                "unexpected match result for {}",
                candidate
            );
        }
    }
}