use std::{collections::HashMap, error::Error};
use async_trait::async_trait;
use axum::{
extract::{
rejection::{FailedToDeserializeForm, FormRejection},
Form, FromRequest, FromRequestParts,
},
response::{IntoResponse, Response},
};
use axum_extra::typed_header::{TypedHeader, TypedHeaderRejectionReason};
use headers::{authorization::Bearer, Authorization, Header, HeaderMapExt, HeaderName};
use http::{header::WWW_AUTHENTICATE, HeaderMap, HeaderValue, Request, StatusCode};
use mas_data_model::Session;
use mas_storage::{
oauth2::{OAuth2AccessTokenRepository, OAuth2SessionRepository},
Clock, RepositoryAccess,
};
use serde::{de::DeserializeOwned, Deserialize};
use thiserror::Error;
#[derive(Debug, Deserialize)]
struct AuthorizedForm<F> {
#[serde(default)]
access_token: Option<String>,
#[serde(flatten)]
inner: F,
}
#[derive(Debug)]
enum AccessToken {
Form(String),
Header(String),
None,
}
impl AccessToken {
async fn fetch<E>(
&self,
repo: &mut impl RepositoryAccess<Error = E>,
) -> Result<(mas_data_model::AccessToken, Session), AuthorizationVerificationError<E>> {
let token = match self {
AccessToken::Form(t) | AccessToken::Header(t) => t,
AccessToken::None => return Err(AuthorizationVerificationError::MissingToken),
};
let token = repo
.oauth2_access_token()
.find_by_token(token.as_str())
.await?
.ok_or(AuthorizationVerificationError::InvalidToken)?;
let session = repo
.oauth2_session()
.lookup(token.session_id)
.await?
.ok_or(AuthorizationVerificationError::InvalidToken)?;
Ok((token, session))
}
}
#[derive(Debug)]
pub struct UserAuthorization<F = ()> {
access_token: AccessToken,
form: Option<F>,
}
impl<F: Send> UserAuthorization<F> {
pub async fn protected_form<E>(
self,
repo: &mut impl RepositoryAccess<Error = E>,
clock: &impl Clock,
) -> Result<(Session, F), AuthorizationVerificationError<E>> {
let Some(form) = self.form else {
return Err(AuthorizationVerificationError::MissingForm);
};
let (token, session) = self.access_token.fetch(repo).await?;
if !token.is_valid(clock.now()) || !session.is_valid() {
return Err(AuthorizationVerificationError::InvalidToken);
}
Ok((session, form))
}
pub async fn protected<E>(
self,
repo: &mut impl RepositoryAccess<Error = E>,
clock: &impl Clock,
) -> Result<Session, AuthorizationVerificationError<E>> {
let (token, session) = self.access_token.fetch(repo).await?;
if !token.is_valid(clock.now()) || !session.is_valid() {
return Err(AuthorizationVerificationError::InvalidToken);
}
if !token.is_used() {
repo.oauth2_access_token().mark_used(clock, token).await?;
}
Ok(session)
}
}
pub enum UserAuthorizationError {
InvalidHeader,
TokenInFormAndHeader,
BadForm(FailedToDeserializeForm),
Internal(Box<dyn Error>),
}
#[derive(Debug, Error)]
pub enum AuthorizationVerificationError<E> {
#[error("missing token")]
MissingToken,
#[error("invalid token")]
InvalidToken,
#[error("missing form")]
MissingForm,
#[error(transparent)]
Internal(#[from] E),
}
enum BearerError {
InvalidRequest,
InvalidToken,
#[allow(dead_code)]
InsufficientScope {
scope: Option<HeaderValue>,
},
}
impl BearerError {
fn error(&self) -> HeaderValue {
match self {
BearerError::InvalidRequest => HeaderValue::from_static("invalid_request"),
BearerError::InvalidToken => HeaderValue::from_static("invalid_token"),
BearerError::InsufficientScope { .. } => HeaderValue::from_static("insufficient_scope"),
}
}
fn params(&self) -> HashMap<&'static str, HeaderValue> {
match self {
BearerError::InsufficientScope { scope: Some(scope) } => {
let mut m = HashMap::new();
m.insert("scope", scope.clone());
m
}
_ => HashMap::new(),
}
}
}
enum WwwAuthenticate {
#[allow(dead_code)]
Basic { realm: HeaderValue },
Bearer {
realm: Option<HeaderValue>,
error: BearerError,
error_description: Option<HeaderValue>,
},
}
impl Header for WwwAuthenticate {
fn name() -> &'static HeaderName {
&WWW_AUTHENTICATE
}
fn decode<'i, I>(_values: &mut I) -> Result<Self, headers::Error>
where
Self: Sized,
I: Iterator<Item = &'i http::HeaderValue>,
{
Err(headers::Error::invalid())
}
fn encode<E: Extend<http::HeaderValue>>(&self, values: &mut E) {
let (scheme, params) = match self {
WwwAuthenticate::Basic { realm } => {
let mut params = HashMap::new();
params.insert("realm", realm.clone());
("Basic", params)
}
WwwAuthenticate::Bearer {
realm,
error,
error_description,
} => {
let mut params = error.params();
params.insert("error", error.error());
if let Some(realm) = realm {
params.insert("realm", realm.clone());
}
if let Some(error_description) = error_description {
params.insert("error_description", error_description.clone());
}
("Bearer", params)
}
};
let params = params.into_iter().map(|(k, v)| format!(" {k}={v:?}"));
let value: String = std::iter::once(scheme.to_owned()).chain(params).collect();
let value = HeaderValue::from_str(&value).unwrap();
values.extend(std::iter::once(value));
}
}
impl IntoResponse for UserAuthorizationError {
fn into_response(self) -> Response {
match self {
Self::BadForm(_) | Self::InvalidHeader | Self::TokenInFormAndHeader => {
let mut headers = HeaderMap::new();
headers.typed_insert(WwwAuthenticate::Bearer {
realm: None,
error: BearerError::InvalidRequest,
error_description: None,
});
(StatusCode::BAD_REQUEST, headers).into_response()
}
Self::Internal(e) => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(),
}
}
}
impl<E> IntoResponse for AuthorizationVerificationError<E>
where
E: ToString,
{
fn into_response(self) -> Response {
match self {
Self::MissingForm | Self::MissingToken => {
let mut headers = HeaderMap::new();
headers.typed_insert(WwwAuthenticate::Bearer {
realm: None,
error: BearerError::InvalidRequest,
error_description: None,
});
(StatusCode::BAD_REQUEST, headers).into_response()
}
Self::InvalidToken => {
let mut headers = HeaderMap::new();
headers.typed_insert(WwwAuthenticate::Bearer {
realm: None,
error: BearerError::InvalidToken,
error_description: None,
});
(StatusCode::BAD_REQUEST, headers).into_response()
}
Self::Internal(e) => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(),
}
}
}
#[async_trait]
impl<S, F> FromRequest<S> for UserAuthorization<F>
where
F: DeserializeOwned,
S: Send + Sync,
{
type Rejection = UserAuthorizationError;
async fn from_request(
req: Request<axum::body::Body>,
state: &S,
) -> Result<Self, Self::Rejection> {
let (mut parts, body) = req.into_parts();
let header =
TypedHeader::<Authorization<Bearer>>::from_request_parts(&mut parts, state).await;
let token_from_header = match header {
Ok(header) => Some(header.token().to_owned()),
Err(err) => match err.reason() {
TypedHeaderRejectionReason::Missing => None,
_ => return Err(UserAuthorizationError::InvalidHeader),
},
};
let req = Request::from_parts(parts, body);
let (token_from_form, form) =
match Form::<AuthorizedForm<F>>::from_request(req, state).await {
Ok(Form(form)) => (form.access_token, Some(form.inner)),
Err(FormRejection::InvalidFormContentType(_err)) => (None, None),
Err(FormRejection::FailedToDeserializeForm(err)) => {
return Err(UserAuthorizationError::BadForm(err))
}
Err(e) => return Err(UserAuthorizationError::Internal(Box::new(e))),
};
let access_token = match (token_from_header, token_from_form) {
(Some(_), Some(_)) => return Err(UserAuthorizationError::TokenInFormAndHeader),
(Some(t), None) => AccessToken::Header(t),
(None, Some(t)) => AccessToken::Form(t),
(None, None) => AccessToken::None,
};
Ok(UserAuthorization { access_token, form })
}
}