pub mod model;
use mas_data_model::{AuthorizationGrant, Client, DeviceCodeGrant, User};
use oauth2_types::{registration::VerifiedClientMetadata, scope::Scope};
use opa_wasm::{
wasmtime::{Config, Engine, Module, OptLevel, Store},
Runtime,
};
use serde::Serialize;
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncReadExt};
use self::model::{AuthorizationGrantInput, ClientRegistrationInput, EmailInput, RegisterInput};
pub use self::model::{Code as ViolationCode, EvaluationResult, Violation};
use crate::model::GrantType;
#[derive(Debug, Error)]
pub enum LoadError {
#[error("failed to read module")]
Read(#[from] tokio::io::Error),
#[error("failed to create WASM engine")]
Engine(#[source] anyhow::Error),
#[error("module compilation task crashed")]
CompilationTask(#[from] tokio::task::JoinError),
#[error("failed to compile WASM module")]
Compilation(#[source] anyhow::Error),
#[error("failed to instantiate a test instance")]
Instantiate(#[source] InstantiateError),
}
#[derive(Debug, Error)]
pub enum InstantiateError {
#[error("failed to create WASM runtime")]
Runtime(#[source] anyhow::Error),
#[error("missing entrypoint {entrypoint}")]
MissingEntrypoint { entrypoint: String },
#[error("failed to load policy data")]
LoadData(#[source] anyhow::Error),
}
#[derive(Debug, Clone)]
pub struct Entrypoints {
pub register: String,
pub client_registration: String,
pub authorization_grant: String,
pub email: String,
}
impl Entrypoints {
fn all(&self) -> [&str; 4] {
[
self.register.as_str(),
self.client_registration.as_str(),
self.authorization_grant.as_str(),
self.email.as_str(),
]
}
}
#[derive(Serialize, Debug)]
pub struct Data {
server_name: String,
#[serde(flatten)]
rest: Option<serde_json::Value>,
}
impl Data {
#[must_use]
pub fn new(server_name: String) -> Self {
Self {
server_name,
rest: None,
}
}
#[must_use]
pub fn with_rest(mut self, rest: serde_json::Value) -> Self {
self.rest = Some(rest);
self
}
}
pub struct PolicyFactory {
engine: Engine,
module: Module,
data: Data,
entrypoints: Entrypoints,
}
impl PolicyFactory {
#[tracing::instrument(name = "policy.load", skip(source), err)]
pub async fn load(
mut source: impl AsyncRead + std::marker::Unpin,
data: Data,
entrypoints: Entrypoints,
) -> Result<Self, LoadError> {
let mut config = Config::default();
config.async_support(true);
config.cranelift_opt_level(OptLevel::SpeedAndSize);
let engine = Engine::new(&config).map_err(LoadError::Engine)?;
let mut buf = Vec::new();
source.read_to_end(&mut buf).await?;
let (engine, module) = tokio::task::spawn_blocking(move || {
let module = Module::new(&engine, buf)?;
anyhow::Ok((engine, module))
})
.await?
.map_err(LoadError::Compilation)?;
let factory = Self {
engine,
module,
data,
entrypoints,
};
factory
.instantiate()
.await
.map_err(LoadError::Instantiate)?;
Ok(factory)
}
#[tracing::instrument(name = "policy.instantiate", skip_all, err)]
pub async fn instantiate(&self) -> Result<Policy, InstantiateError> {
let mut store = Store::new(&self.engine, ());
let runtime = Runtime::new(&mut store, &self.module)
.await
.map_err(InstantiateError::Runtime)?;
let policy_entrypoints = runtime.entrypoints();
for e in self.entrypoints.all() {
if !policy_entrypoints.contains(e) {
return Err(InstantiateError::MissingEntrypoint {
entrypoint: e.to_owned(),
});
}
}
let instance = runtime
.with_data(&mut store, &self.data)
.await
.map_err(InstantiateError::LoadData)?;
Ok(Policy {
store,
instance,
entrypoints: self.entrypoints.clone(),
})
}
}
pub struct Policy {
store: Store<()>,
instance: opa_wasm::Policy<opa_wasm::DefaultContext>,
entrypoints: Entrypoints,
}
#[derive(Debug, Error)]
#[error("failed to evaluate policy")]
pub enum EvaluationError {
Serialization(#[from] serde_json::Error),
Evaluation(#[from] anyhow::Error),
}
impl Policy {
#[tracing::instrument(
name = "policy.evaluate_email",
skip_all,
fields(
input.email = email,
),
err,
)]
pub async fn evaluate_email(
&mut self,
email: &str,
) -> Result<EvaluationResult, EvaluationError> {
let input = EmailInput { email };
let [res]: [EvaluationResult; 1] = self
.instance
.evaluate(&mut self.store, &self.entrypoints.email, &input)
.await?;
Ok(res)
}
#[tracing::instrument(
name = "policy.evaluate.register",
skip_all,
fields(
input.registration_method = "password",
input.user.username = username,
input.user.email = email,
),
err,
)]
pub async fn evaluate_register(
&mut self,
username: &str,
email: &str,
) -> Result<EvaluationResult, EvaluationError> {
let input = RegisterInput::Password { username, email };
let [res]: [EvaluationResult; 1] = self
.instance
.evaluate(&mut self.store, &self.entrypoints.register, &input)
.await?;
Ok(res)
}
#[tracing::instrument(
name = "policy.evaluate.upstream_oauth_register",
skip_all,
fields(
input.registration_method = "password",
input.user.username = username,
input.user.email = email,
),
err,
)]
pub async fn evaluate_upstream_oauth_register(
&mut self,
username: &str,
email: Option<&str>,
) -> Result<EvaluationResult, EvaluationError> {
let input = RegisterInput::UpstreamOAuth2 { username, email };
let [res]: [EvaluationResult; 1] = self
.instance
.evaluate(&mut self.store, &self.entrypoints.register, &input)
.await?;
Ok(res)
}
#[tracing::instrument(skip(self))]
pub async fn evaluate_client_registration(
&mut self,
client_metadata: &VerifiedClientMetadata,
) -> Result<EvaluationResult, EvaluationError> {
let input = ClientRegistrationInput { client_metadata };
let [res]: [EvaluationResult; 1] = self
.instance
.evaluate(
&mut self.store,
&self.entrypoints.client_registration,
&input,
)
.await?;
Ok(res)
}
#[tracing::instrument(
name = "policy.evaluate.authorization_grant",
skip_all,
fields(
input.authorization_grant.id = %authorization_grant.id,
input.scope = %authorization_grant.scope,
input.client.id = %client.id,
input.user.id = %user.id,
),
err,
)]
pub async fn evaluate_authorization_grant(
&mut self,
authorization_grant: &AuthorizationGrant,
client: &Client,
user: &User,
) -> Result<EvaluationResult, EvaluationError> {
let input = AuthorizationGrantInput {
user: Some(user),
client,
scope: &authorization_grant.scope,
grant_type: GrantType::AuthorizationCode,
};
let [res]: [EvaluationResult; 1] = self
.instance
.evaluate(
&mut self.store,
&self.entrypoints.authorization_grant,
&input,
)
.await?;
Ok(res)
}
#[tracing::instrument(
name = "policy.evaluate.client_credentials_grant",
skip_all,
fields(
input.scope = %scope,
input.client.id = %client.id,
),
err,
)]
pub async fn evaluate_client_credentials_grant(
&mut self,
scope: &Scope,
client: &Client,
) -> Result<EvaluationResult, EvaluationError> {
let input = AuthorizationGrantInput {
user: None,
client,
scope,
grant_type: GrantType::ClientCredentials,
};
let [res]: [EvaluationResult; 1] = self
.instance
.evaluate(
&mut self.store,
&self.entrypoints.authorization_grant,
&input,
)
.await?;
Ok(res)
}
#[tracing::instrument(
name = "policy.evaluate.device_code_grant",
skip_all,
fields(
input.device_code_grant.id = %device_code_grant.id,
input.scope = %device_code_grant.scope,
input.client.id = %client.id,
input.user.id = %user.id,
),
err,
)]
pub async fn evaluate_device_code_grant(
&mut self,
device_code_grant: &DeviceCodeGrant,
client: &Client,
user: &User,
) -> Result<EvaluationResult, EvaluationError> {
let input = AuthorizationGrantInput {
user: Some(user),
client,
scope: &device_code_grant.scope,
grant_type: GrantType::DeviceCode,
};
let [res]: [EvaluationResult; 1] = self
.instance
.evaluate(
&mut self.store,
&self.entrypoints.authorization_grant,
&input,
)
.await?;
Ok(res)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_register() {
let data = Data::new("example.com".to_owned()).with_rest(serde_json::json!({
"allowed_domains": ["element.io", "*.element.io"],
"banned_domains": ["staging.element.io"],
}));
#[allow(clippy::disallowed_types)]
let path = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
.join("..")
.join("..")
.join("policies")
.join("policy.wasm");
let file = tokio::fs::File::open(path).await.unwrap();
let entrypoints = Entrypoints {
register: "register/violation".to_owned(),
client_registration: "client_registration/violation".to_owned(),
authorization_grant: "authorization_grant/violation".to_owned(),
email: "email/violation".to_owned(),
};
let factory = PolicyFactory::load(file, data, entrypoints).await.unwrap();
let mut policy = factory.instantiate().await.unwrap();
let res = policy
.evaluate_register("hello", "hello@example.com")
.await
.unwrap();
assert!(!res.valid());
let res = policy
.evaluate_register("hello", "hello@foo.element.io")
.await
.unwrap();
assert!(res.valid());
let res = policy
.evaluate_register("hello", "hello@staging.element.io")
.await
.unwrap();
assert!(!res.valid());
}
}