1pub mod model;
8
9use std::sync::Arc;
10
11use arc_swap::ArcSwap;
12use mas_data_model::Ulid;
13use opa_wasm::{
14 Runtime,
15 wasmtime::{Config, Engine, Module, OptLevel, Store},
16};
17use thiserror::Error;
18use tokio::io::{AsyncRead, AsyncReadExt};
19
20pub use self::model::{
21 AuthorizationGrantInput, ClientRegistrationInput, Code as ViolationCode, EmailInput,
22 EvaluationResult, GrantType, RegisterInput, RegistrationMethod, Requester, Violation,
23};
24
25#[derive(Debug, Error)]
26pub enum LoadError {
27 #[error("failed to read module")]
28 Read(#[from] tokio::io::Error),
29
30 #[error("failed to create WASM engine")]
31 Engine(#[source] anyhow::Error),
32
33 #[error("module compilation task crashed")]
34 CompilationTask(#[from] tokio::task::JoinError),
35
36 #[error("failed to compile WASM module")]
37 Compilation(#[source] anyhow::Error),
38
39 #[error("invalid policy data")]
40 InvalidData(#[source] anyhow::Error),
41
42 #[error("failed to instantiate a test instance")]
43 Instantiate(#[source] InstantiateError),
44}
45
46impl LoadError {
47 #[doc(hidden)]
50 #[must_use]
51 pub fn invalid_data_example() -> Self {
52 Self::InvalidData(anyhow::Error::msg("Failed to merge policy data objects"))
53 }
54}
55
56#[derive(Debug, Error)]
57pub enum InstantiateError {
58 #[error("failed to create WASM runtime")]
59 Runtime(#[source] anyhow::Error),
60
61 #[error("missing entrypoint {entrypoint}")]
62 MissingEntrypoint { entrypoint: String },
63
64 #[error("failed to load policy data")]
65 LoadData(#[source] anyhow::Error),
66}
67
/// Names of the policy entrypoints evaluated for each kind of check.
#[derive(Debug, Clone)]
pub struct Entrypoints {
    /// Entrypoint evaluated by `Policy::evaluate_register`.
    pub register: String,
    /// Entrypoint evaluated by `Policy::evaluate_client_registration`.
    pub client_registration: String,
    /// Entrypoint evaluated by `Policy::evaluate_authorization_grant`.
    pub authorization_grant: String,
    /// Entrypoint evaluated by `Policy::evaluate_email`.
    pub email: String,
}

impl Entrypoints {
    /// Returns every configured entrypoint name, in field declaration order
    /// (register, client registration, authorization grant, email).
    fn all(&self) -> [&str; 4] {
        let Self {
            register,
            client_registration,
            authorization_grant,
            email,
        } = self;

        [register, client_registration, authorization_grant, email].map(String::as_str)
    }
}
87
88#[derive(Debug)]
89pub struct Data {
90 server_name: String,
91
92 rest: Option<serde_json::Value>,
93}
94
95impl Data {
96 #[must_use]
97 pub fn new(server_name: String) -> Self {
98 Self {
99 server_name,
100 rest: None,
101 }
102 }
103
104 #[must_use]
105 pub fn with_rest(mut self, rest: serde_json::Value) -> Self {
106 self.rest = Some(rest);
107 self
108 }
109
110 fn to_value(&self) -> Result<serde_json::Value, anyhow::Error> {
111 let base = serde_json::json!({
112 "server_name": self.server_name,
113 });
114
115 if let Some(rest) = &self.rest {
116 merge_data(base, rest.clone())
117 } else {
118 Ok(base)
119 }
120 }
121}
122
123fn value_kind(value: &serde_json::Value) -> &'static str {
124 match value {
125 serde_json::Value::Object(_) => "object",
126 serde_json::Value::Array(_) => "array",
127 serde_json::Value::String(_) => "string",
128 serde_json::Value::Number(_) => "number",
129 serde_json::Value::Bool(_) => "boolean",
130 serde_json::Value::Null => "null",
131 }
132}
133
134fn merge_data(
135 mut left: serde_json::Value,
136 right: serde_json::Value,
137) -> Result<serde_json::Value, anyhow::Error> {
138 merge_data_rec(&mut left, right)?;
139 Ok(left)
140}
141
142fn merge_data_rec(
143 left: &mut serde_json::Value,
144 right: serde_json::Value,
145) -> Result<(), anyhow::Error> {
146 match (left, right) {
147 (serde_json::Value::Object(left), serde_json::Value::Object(right)) => {
148 for (key, value) in right {
149 if let Some(left_value) = left.get_mut(&key) {
150 merge_data_rec(left_value, value)?;
151 } else {
152 left.insert(key, value);
153 }
154 }
155 }
156 (serde_json::Value::Array(left), serde_json::Value::Array(right)) => {
157 left.extend(right);
158 }
159 (serde_json::Value::Number(left), serde_json::Value::Number(right)) => {
161 *left = right;
162 }
163 (serde_json::Value::Bool(left), serde_json::Value::Bool(right)) => {
164 *left = right;
165 }
166 (serde_json::Value::String(left), serde_json::Value::String(right)) => {
167 *left = right;
168 }
169
170 (left, right) if left.is_null() => *left = right,
172
173 (left, right) if right.is_null() => *left = right,
175
176 (left, right) => anyhow::bail!(
177 "Cannot merge a {} into a {}",
178 value_kind(&right),
179 value_kind(left),
180 ),
181 }
182
183 Ok(())
184}
185
186struct DynamicData {
187 version: Option<Ulid>,
188 merged: serde_json::Value,
189}
190
191pub struct PolicyFactory {
192 engine: Engine,
193 module: Module,
194 data: Data,
195 dynamic_data: ArcSwap<DynamicData>,
196 entrypoints: Entrypoints,
197}
198
199impl PolicyFactory {
200 #[tracing::instrument(name = "policy.load", skip(source))]
206 pub async fn load(
207 mut source: impl AsyncRead + std::marker::Unpin,
208 data: Data,
209 entrypoints: Entrypoints,
210 ) -> Result<Self, LoadError> {
211 let mut config = Config::default();
212 config.async_support(true);
213 config.cranelift_opt_level(OptLevel::SpeedAndSize);
214
215 let engine = Engine::new(&config).map_err(LoadError::Engine)?;
216
217 let mut buf = Vec::new();
219 source.read_to_end(&mut buf).await?;
220 let (engine, module) = tokio::task::spawn_blocking(move || {
222 let module = Module::new(&engine, buf)?;
223 anyhow::Ok((engine, module))
224 })
225 .await?
226 .map_err(LoadError::Compilation)?;
227
228 let merged = data.to_value().map_err(LoadError::InvalidData)?;
229 let dynamic_data = ArcSwap::new(Arc::new(DynamicData {
230 version: None,
231 merged,
232 }));
233
234 let factory = Self {
235 engine,
236 module,
237 data,
238 dynamic_data,
239 entrypoints,
240 };
241
242 factory
244 .instantiate()
245 .await
246 .map_err(LoadError::Instantiate)?;
247
248 Ok(factory)
249 }
250
251 pub async fn set_dynamic_data(
264 &self,
265 dynamic_data: mas_data_model::PolicyData,
266 ) -> Result<bool, LoadError> {
267 if self.dynamic_data.load().version == Some(dynamic_data.id) {
270 return Ok(false);
272 }
273
274 let static_data = self.data.to_value().map_err(LoadError::InvalidData)?;
275 let merged = merge_data(static_data, dynamic_data.data).map_err(LoadError::InvalidData)?;
276
277 self.instantiate_with_data(&merged)
279 .await
280 .map_err(LoadError::Instantiate)?;
281
282 self.dynamic_data.store(Arc::new(DynamicData {
284 version: Some(dynamic_data.id),
285 merged,
286 }));
287
288 Ok(true)
289 }
290
291 #[tracing::instrument(name = "policy.instantiate", skip_all)]
298 pub async fn instantiate(&self) -> Result<Policy, InstantiateError> {
299 let data = self.dynamic_data.load();
300 self.instantiate_with_data(&data.merged).await
301 }
302
303 async fn instantiate_with_data(
304 &self,
305 data: &serde_json::Value,
306 ) -> Result<Policy, InstantiateError> {
307 let mut store = Store::new(&self.engine, ());
308 let runtime = Runtime::new(&mut store, &self.module)
309 .await
310 .map_err(InstantiateError::Runtime)?;
311
312 let policy_entrypoints = runtime.entrypoints();
314
315 for e in self.entrypoints.all() {
316 if !policy_entrypoints.contains(e) {
317 return Err(InstantiateError::MissingEntrypoint {
318 entrypoint: e.to_owned(),
319 });
320 }
321 }
322
323 let instance = runtime
324 .with_data(&mut store, data)
325 .await
326 .map_err(InstantiateError::LoadData)?;
327
328 Ok(Policy {
329 store,
330 instance,
331 entrypoints: self.entrypoints.clone(),
332 })
333 }
334}
335
336pub struct Policy {
337 store: Store<()>,
338 instance: opa_wasm::Policy<opa_wasm::DefaultContext>,
339 entrypoints: Entrypoints,
340}
341
342#[derive(Debug, Error)]
343#[error("failed to evaluate policy")]
344pub enum EvaluationError {
345 Serialization(#[from] serde_json::Error),
346 Evaluation(#[from] anyhow::Error),
347}
348
349impl Policy {
350 #[tracing::instrument(
356 name = "policy.evaluate_email",
357 skip_all,
358 fields(
359 %input.email,
360 ),
361 )]
362 pub async fn evaluate_email(
363 &mut self,
364 input: EmailInput<'_>,
365 ) -> Result<EvaluationResult, EvaluationError> {
366 let [res]: [EvaluationResult; 1] = self
367 .instance
368 .evaluate(&mut self.store, &self.entrypoints.email, &input)
369 .await?;
370
371 Ok(res)
372 }
373
374 #[tracing::instrument(
380 name = "policy.evaluate.register",
381 skip_all,
382 fields(
383 ?input.registration_method,
384 input.username = input.username,
385 input.email = input.email,
386 ),
387 )]
388 pub async fn evaluate_register(
389 &mut self,
390 input: RegisterInput<'_>,
391 ) -> Result<EvaluationResult, EvaluationError> {
392 let [res]: [EvaluationResult; 1] = self
393 .instance
394 .evaluate(&mut self.store, &self.entrypoints.register, &input)
395 .await?;
396
397 Ok(res)
398 }
399
400 #[tracing::instrument(skip(self))]
406 pub async fn evaluate_client_registration(
407 &mut self,
408 input: ClientRegistrationInput<'_>,
409 ) -> Result<EvaluationResult, EvaluationError> {
410 let [res]: [EvaluationResult; 1] = self
411 .instance
412 .evaluate(
413 &mut self.store,
414 &self.entrypoints.client_registration,
415 &input,
416 )
417 .await?;
418
419 Ok(res)
420 }
421
422 #[tracing::instrument(
428 name = "policy.evaluate.authorization_grant",
429 skip_all,
430 fields(
431 %input.scope,
432 %input.client.id,
433 ),
434 )]
435 pub async fn evaluate_authorization_grant(
436 &mut self,
437 input: AuthorizationGrantInput<'_>,
438 ) -> Result<EvaluationResult, EvaluationError> {
439 let [res]: [EvaluationResult; 1] = self
440 .instance
441 .evaluate(
442 &mut self.store,
443 &self.entrypoints.authorization_grant,
444 &input,
445 )
446 .await?;
447
448 Ok(res)
449 }
450}
451
#[cfg(test)]
mod tests {

    use std::time::SystemTime;

    use super::*;

    /// Entrypoints exported by the bundled `policies/policy.wasm` module.
    fn entrypoints() -> Entrypoints {
        Entrypoints {
            register: "register/violation".to_owned(),
            client_registration: "client_registration/violation".to_owned(),
            authorization_grant: "authorization_grant/violation".to_owned(),
            email: "email/violation".to_owned(),
        }
    }

    /// Loads `policies/policy.wasm` (two directories above this crate's
    /// manifest) with the given static data, panicking on any failure.
    async fn load_factory(data: Data) -> PolicyFactory {
        #[allow(clippy::disallowed_types)]
        let path = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
            .join("..")
            .join("..")
            .join("policies")
            .join("policy.wasm");

        let file = tokio::fs::File::open(path).await.unwrap();

        PolicyFactory::load(file, data, entrypoints()).await.unwrap()
    }

    /// Password registration of user `hello` with the given email address
    /// and no requester information.
    fn password_registration(email: &str) -> RegisterInput<'_> {
        RegisterInput {
            registration_method: RegistrationMethod::Password,
            username: "hello",
            email: Some(email),
            requester: Requester {
                ip_address: None,
                user_agent: None,
            },
        }
    }

    /// Wraps `data` as dynamic policy data with a fixed (nil) id, so that
    /// repeated calls produce the same version.
    fn policy_data(data: serde_json::Value) -> mas_data_model::PolicyData {
        mas_data_model::PolicyData {
            id: Ulid::nil(),
            created_at: SystemTime::now().into(),
            data,
        }
    }

    #[tokio::test]
    async fn test_register() {
        let data = Data::new("example.com".to_owned()).with_rest(serde_json::json!({
            "allowed_domains": ["element.io", "*.element.io"],
            "banned_domains": ["staging.element.io"],
        }));
        let factory = load_factory(data).await;
        let mut policy = factory.instantiate().await.unwrap();

        // example.com is not listed in `allowed_domains`.
        let res = policy
            .evaluate_register(password_registration("hello@example.com"))
            .await
            .unwrap();
        assert!(!res.valid());

        // Matches the `*.element.io` entry.
        let res = policy
            .evaluate_register(password_registration("hello@foo.element.io"))
            .await
            .unwrap();
        assert!(res.valid());

        // Listed in `banned_domains`.
        let res = policy
            .evaluate_register(password_registration("hello@staging.element.io"))
            .await
            .unwrap();
        assert!(!res.valid());
    }

    #[tokio::test]
    async fn test_dynamic_data() {
        let factory = load_factory(Data::new("example.com".to_owned())).await;

        let mut policy = factory.instantiate().await.unwrap();
        let res = policy
            .evaluate_register(password_registration("hello@example.com"))
            .await
            .unwrap();
        assert!(res.valid());

        let banned_hello = serde_json::json!({
            "emails": {
                "banned_addresses": {
                    "substrings": ["hello"]
                }
            }
        });

        let changed = factory
            .set_dynamic_data(policy_data(banned_hello.clone()))
            .await
            .unwrap();
        assert!(changed, "new dynamic data should be applied");

        // Same id as the active data: the update must be a no-op.
        let changed = factory
            .set_dynamic_data(policy_data(banned_hello))
            .await
            .unwrap();
        assert!(!changed, "re-applying the same version should be skipped");

        let mut policy = factory.instantiate().await.unwrap();
        let res = policy
            .evaluate_register(password_registration("hello@example.com"))
            .await
            .unwrap();
        assert!(!res.valid());
    }

    #[tokio::test]
    async fn test_big_dynamic_data() {
        let factory = load_factory(Data::new("example.com".to_owned())).await;

        let substrings: Vec<String> = (0..(1024 * 1024 / 8))
            .map(|i| format!("{:05}", i % 100_000))
            .collect();
        let json =
            serde_json::json!({ "emails": { "banned_addresses": { "substrings": substrings } } });

        let changed = factory.set_dynamic_data(policy_data(json)).await.unwrap();
        assert!(changed, "new dynamic data should be applied");

        let mut policy = factory.instantiate().await.unwrap();
        let res = policy
            .evaluate_register(password_registration("12345@example.com"))
            .await
            .unwrap();
        assert!(!res.valid());
    }

    #[test]
    fn test_merge() {
        use serde_json::json as j;

        let res = merge_data(j!({"hello": "world"}), j!({"foo": "bar"})).unwrap();
        assert_eq!(res, j!({"hello": "world", "foo": "bar"}));

        let res = merge_data(j!({"hello": "world"}), j!({"hello": "john"})).unwrap();
        assert_eq!(res, j!({"hello": "john"}));

        let res = merge_data(j!({"hello": true}), j!({"hello": false})).unwrap();
        assert_eq!(res, j!({"hello": false}));

        let res = merge_data(j!({"hello": 0}), j!({"hello": 42})).unwrap();
        assert_eq!(res, j!({"hello": 42}));

        merge_data(j!({"hello": "world"}), j!({"hello": 123}))
            .expect_err("Can't merge different types");

        let res = merge_data(j!({"hello": ["world"]}), j!({"hello": ["john"]})).unwrap();
        assert_eq!(res, j!({"hello": ["world", "john"]}));

        let res = merge_data(j!({"hello": "world"}), j!({"hello": null})).unwrap();
        assert_eq!(res, j!({"hello": null}));

        let res = merge_data(j!({"hello": null}), j!({"hello": "world"})).unwrap();
        assert_eq!(res, j!({"hello": "world"}));

        let res = merge_data(j!({"a": {"b": {"c": "d"}}}), j!({"a": {"b": {"e": "f"}}})).unwrap();
        assert_eq!(res, j!({"a": {"b": {"c": "d", "e": "f"}}}));
    }
}