mas_policy/
lib.rs

1// Copyright 2024, 2025 New Vector Ltd.
2// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
5// Please see LICENSE files in the repository root for full details.
6
7pub mod model;
8
9use std::sync::Arc;
10
11use arc_swap::ArcSwap;
12use mas_data_model::Ulid;
13use opa_wasm::{
14    Runtime,
15    wasmtime::{Config, Engine, Module, OptLevel, Store},
16};
17use thiserror::Error;
18use tokio::io::{AsyncRead, AsyncReadExt};
19
20pub use self::model::{
21    AuthorizationGrantInput, ClientRegistrationInput, Code as ViolationCode, EmailInput,
22    EvaluationResult, GrantType, RegisterInput, RegistrationMethod, Requester, Violation,
23};
24
25#[derive(Debug, Error)]
26pub enum LoadError {
27    #[error("failed to read module")]
28    Read(#[from] tokio::io::Error),
29
30    #[error("failed to create WASM engine")]
31    Engine(#[source] anyhow::Error),
32
33    #[error("module compilation task crashed")]
34    CompilationTask(#[from] tokio::task::JoinError),
35
36    #[error("failed to compile WASM module")]
37    Compilation(#[source] anyhow::Error),
38
39    #[error("invalid policy data")]
40    InvalidData(#[source] anyhow::Error),
41
42    #[error("failed to instantiate a test instance")]
43    Instantiate(#[source] InstantiateError),
44}
45
46impl LoadError {
47    /// Creates an example of an invalid data error, used for API response
48    /// documentation
49    #[doc(hidden)]
50    #[must_use]
51    pub fn invalid_data_example() -> Self {
52        Self::InvalidData(anyhow::Error::msg("Failed to merge policy data objects"))
53    }
54}
55
56#[derive(Debug, Error)]
57pub enum InstantiateError {
58    #[error("failed to create WASM runtime")]
59    Runtime(#[source] anyhow::Error),
60
61    #[error("missing entrypoint {entrypoint}")]
62    MissingEntrypoint { entrypoint: String },
63
64    #[error("failed to load policy data")]
65    LoadData(#[source] anyhow::Error),
66}
67
/// Holds the entrypoint of each policy
///
/// Every entrypoint listed here must be exported by the policy module,
/// otherwise instantiation fails with [`InstantiateError::MissingEntrypoint`].
#[derive(Clone, Debug)]
pub struct Entrypoints {
    /// Entrypoint evaluated by [`Policy::evaluate_register`].
    pub register: String,
    /// Entrypoint evaluated by [`Policy::evaluate_client_registration`].
    pub client_registration: String,
    /// Entrypoint evaluated by [`Policy::evaluate_authorization_grant`].
    pub authorization_grant: String,
    /// Entrypoint evaluated by [`Policy::evaluate_email`].
    pub email: String,
}
76
77impl Entrypoints {
78    fn all(&self) -> [&str; 4] {
79        [
80            self.register.as_str(),
81            self.client_registration.as_str(),
82            self.authorization_grant.as_str(),
83            self.email.as_str(),
84        ]
85    }
86}
87
88#[derive(Debug)]
89pub struct Data {
90    server_name: String,
91
92    rest: Option<serde_json::Value>,
93}
94
95impl Data {
96    #[must_use]
97    pub fn new(server_name: String) -> Self {
98        Self {
99            server_name,
100            rest: None,
101        }
102    }
103
104    #[must_use]
105    pub fn with_rest(mut self, rest: serde_json::Value) -> Self {
106        self.rest = Some(rest);
107        self
108    }
109
110    fn to_value(&self) -> Result<serde_json::Value, anyhow::Error> {
111        let base = serde_json::json!({
112            "server_name": self.server_name,
113        });
114
115        if let Some(rest) = &self.rest {
116            merge_data(base, rest.clone())
117        } else {
118            Ok(base)
119        }
120    }
121}
122
123fn value_kind(value: &serde_json::Value) -> &'static str {
124    match value {
125        serde_json::Value::Object(_) => "object",
126        serde_json::Value::Array(_) => "array",
127        serde_json::Value::String(_) => "string",
128        serde_json::Value::Number(_) => "number",
129        serde_json::Value::Bool(_) => "boolean",
130        serde_json::Value::Null => "null",
131    }
132}
133
134fn merge_data(
135    mut left: serde_json::Value,
136    right: serde_json::Value,
137) -> Result<serde_json::Value, anyhow::Error> {
138    merge_data_rec(&mut left, right)?;
139    Ok(left)
140}
141
142fn merge_data_rec(
143    left: &mut serde_json::Value,
144    right: serde_json::Value,
145) -> Result<(), anyhow::Error> {
146    match (left, right) {
147        (serde_json::Value::Object(left), serde_json::Value::Object(right)) => {
148            for (key, value) in right {
149                if let Some(left_value) = left.get_mut(&key) {
150                    merge_data_rec(left_value, value)?;
151                } else {
152                    left.insert(key, value);
153                }
154            }
155        }
156        (serde_json::Value::Array(left), serde_json::Value::Array(right)) => {
157            left.extend(right);
158        }
159        // Other values override
160        (serde_json::Value::Number(left), serde_json::Value::Number(right)) => {
161            *left = right;
162        }
163        (serde_json::Value::Bool(left), serde_json::Value::Bool(right)) => {
164            *left = right;
165        }
166        (serde_json::Value::String(left), serde_json::Value::String(right)) => {
167            *left = right;
168        }
169
170        // Null gets overridden by anything
171        (left, right) if left.is_null() => *left = right,
172
173        // Null on the right makes the left value null
174        (left, right) if right.is_null() => *left = right,
175
176        (left, right) => anyhow::bail!(
177            "Cannot merge a {} into a {}",
178            value_kind(&right),
179            value_kind(left),
180        ),
181    }
182
183    Ok(())
184}
185
186struct DynamicData {
187    version: Option<Ulid>,
188    merged: serde_json::Value,
189}
190
191pub struct PolicyFactory {
192    engine: Engine,
193    module: Module,
194    data: Data,
195    dynamic_data: ArcSwap<DynamicData>,
196    entrypoints: Entrypoints,
197}
198
199impl PolicyFactory {
200    /// Load the policy from the given data source.
201    ///
202    /// # Errors
203    ///
204    /// Returns an error if the policy can't be loaded or instantiated.
205    #[tracing::instrument(name = "policy.load", skip(source))]
206    pub async fn load(
207        mut source: impl AsyncRead + std::marker::Unpin,
208        data: Data,
209        entrypoints: Entrypoints,
210    ) -> Result<Self, LoadError> {
211        let mut config = Config::default();
212        config.async_support(true);
213        config.cranelift_opt_level(OptLevel::SpeedAndSize);
214
215        let engine = Engine::new(&config).map_err(LoadError::Engine)?;
216
217        // Read and compile the module
218        let mut buf = Vec::new();
219        source.read_to_end(&mut buf).await?;
220        // Compilation is CPU-bound, so spawn that in a blocking task
221        let (engine, module) = tokio::task::spawn_blocking(move || {
222            let module = Module::new(&engine, buf)?;
223            anyhow::Ok((engine, module))
224        })
225        .await?
226        .map_err(LoadError::Compilation)?;
227
228        let merged = data.to_value().map_err(LoadError::InvalidData)?;
229        let dynamic_data = ArcSwap::new(Arc::new(DynamicData {
230            version: None,
231            merged,
232        }));
233
234        let factory = Self {
235            engine,
236            module,
237            data,
238            dynamic_data,
239            entrypoints,
240        };
241
242        // Try to instantiate
243        factory
244            .instantiate()
245            .await
246            .map_err(LoadError::Instantiate)?;
247
248        Ok(factory)
249    }
250
251    /// Set the dynamic data for the policy.
252    ///
253    /// The `dynamic_data` object is merged with the static data given when the
254    /// policy was loaded.
255    ///
256    /// Returns `true` if the data was updated, `false` if the version
257    /// of the dynamic data was the same as the one we already have.
258    ///
259    /// # Errors
260    ///
261    /// Returns an error if the data can't be merged with the static data, or if
262    /// the policy can't be instantiated with the new data.
263    pub async fn set_dynamic_data(
264        &self,
265        dynamic_data: mas_data_model::PolicyData,
266    ) -> Result<bool, LoadError> {
267        // Check if the version of the dynamic data we have is the same as the one we're
268        // trying to set
269        if self.dynamic_data.load().version == Some(dynamic_data.id) {
270            // Don't do anything if the version is the same
271            return Ok(false);
272        }
273
274        let static_data = self.data.to_value().map_err(LoadError::InvalidData)?;
275        let merged = merge_data(static_data, dynamic_data.data).map_err(LoadError::InvalidData)?;
276
277        // Try to instantiate with the new data
278        self.instantiate_with_data(&merged)
279            .await
280            .map_err(LoadError::Instantiate)?;
281
282        // If instantiation succeeds, swap the data
283        self.dynamic_data.store(Arc::new(DynamicData {
284            version: Some(dynamic_data.id),
285            merged,
286        }));
287
288        Ok(true)
289    }
290
291    /// Create a new policy instance.
292    ///
293    /// # Errors
294    ///
295    /// Returns an error if the policy can't be instantiated with the current
296    /// dynamic data.
297    #[tracing::instrument(name = "policy.instantiate", skip_all)]
298    pub async fn instantiate(&self) -> Result<Policy, InstantiateError> {
299        let data = self.dynamic_data.load();
300        self.instantiate_with_data(&data.merged).await
301    }
302
303    async fn instantiate_with_data(
304        &self,
305        data: &serde_json::Value,
306    ) -> Result<Policy, InstantiateError> {
307        let mut store = Store::new(&self.engine, ());
308        let runtime = Runtime::new(&mut store, &self.module)
309            .await
310            .map_err(InstantiateError::Runtime)?;
311
312        // Check that we have the required entrypoints
313        let policy_entrypoints = runtime.entrypoints();
314
315        for e in self.entrypoints.all() {
316            if !policy_entrypoints.contains(e) {
317                return Err(InstantiateError::MissingEntrypoint {
318                    entrypoint: e.to_owned(),
319                });
320            }
321        }
322
323        let instance = runtime
324            .with_data(&mut store, data)
325            .await
326            .map_err(InstantiateError::LoadData)?;
327
328        Ok(Policy {
329            store,
330            instance,
331            entrypoints: self.entrypoints.clone(),
332        })
333    }
334}
335
336pub struct Policy {
337    store: Store<()>,
338    instance: opa_wasm::Policy<opa_wasm::DefaultContext>,
339    entrypoints: Entrypoints,
340}
341
342#[derive(Debug, Error)]
343#[error("failed to evaluate policy")]
344pub enum EvaluationError {
345    Serialization(#[from] serde_json::Error),
346    Evaluation(#[from] anyhow::Error),
347}
348
349impl Policy {
350    /// Evaluate the 'email' entrypoint.
351    ///
352    /// # Errors
353    ///
354    /// Returns an error if the policy engine fails to evaluate the entrypoint.
355    #[tracing::instrument(
356        name = "policy.evaluate_email",
357        skip_all,
358        fields(
359            %input.email,
360        ),
361    )]
362    pub async fn evaluate_email(
363        &mut self,
364        input: EmailInput<'_>,
365    ) -> Result<EvaluationResult, EvaluationError> {
366        let [res]: [EvaluationResult; 1] = self
367            .instance
368            .evaluate(&mut self.store, &self.entrypoints.email, &input)
369            .await?;
370
371        Ok(res)
372    }
373
374    /// Evaluate the 'register' entrypoint.
375    ///
376    /// # Errors
377    ///
378    /// Returns an error if the policy engine fails to evaluate the entrypoint.
379    #[tracing::instrument(
380        name = "policy.evaluate.register",
381        skip_all,
382        fields(
383            ?input.registration_method,
384            input.username = input.username,
385            input.email = input.email,
386        ),
387    )]
388    pub async fn evaluate_register(
389        &mut self,
390        input: RegisterInput<'_>,
391    ) -> Result<EvaluationResult, EvaluationError> {
392        let [res]: [EvaluationResult; 1] = self
393            .instance
394            .evaluate(&mut self.store, &self.entrypoints.register, &input)
395            .await?;
396
397        Ok(res)
398    }
399
400    /// Evaluate the 'client_registration' entrypoint.
401    ///
402    /// # Errors
403    ///
404    /// Returns an error if the policy engine fails to evaluate the entrypoint.
405    #[tracing::instrument(skip(self))]
406    pub async fn evaluate_client_registration(
407        &mut self,
408        input: ClientRegistrationInput<'_>,
409    ) -> Result<EvaluationResult, EvaluationError> {
410        let [res]: [EvaluationResult; 1] = self
411            .instance
412            .evaluate(
413                &mut self.store,
414                &self.entrypoints.client_registration,
415                &input,
416            )
417            .await?;
418
419        Ok(res)
420    }
421
422    /// Evaluate the 'authorization_grant' entrypoint.
423    ///
424    /// # Errors
425    ///
426    /// Returns an error if the policy engine fails to evaluate the entrypoint.
427    #[tracing::instrument(
428        name = "policy.evaluate.authorization_grant",
429        skip_all,
430        fields(
431            %input.scope,
432            %input.client.id,
433        ),
434    )]
435    pub async fn evaluate_authorization_grant(
436        &mut self,
437        input: AuthorizationGrantInput<'_>,
438    ) -> Result<EvaluationResult, EvaluationError> {
439        let [res]: [EvaluationResult; 1] = self
440            .instance
441            .evaluate(
442                &mut self.store,
443                &self.entrypoints.authorization_grant,
444                &input,
445            )
446            .await?;
447
448        Ok(res)
449    }
450}
451
#[cfg(test)]
mod tests {
    use std::time::SystemTime;

    use super::*;

    /// Location of the compiled policy bundle exercised by these tests.
    // NOTE(review): presumably `std::path` types are listed in the clippy
    // `disallowed-types` configuration; the tests only need them to locate
    // the file relative to `CARGO_MANIFEST_DIR`.
    #[allow(clippy::disallowed_types)]
    fn policy_path() -> std::path::PathBuf {
        std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
            .join("..")
            .join("..")
            .join("policies")
            .join("policy.wasm")
    }

    /// Entrypoint names used with the bundled policy.
    fn entrypoints() -> Entrypoints {
        Entrypoints {
            register: "register/violation".to_owned(),
            client_registration: "client_registration/violation".to_owned(),
            authorization_grant: "authorization_grant/violation".to_owned(),
            email: "email/violation".to_owned(),
        }
    }

    /// Loads the bundled policy with the given static data.
    ///
    /// # Panics
    ///
    /// Panics if the policy file can't be opened or the policy can't be loaded.
    async fn load_factory(data: Data) -> PolicyFactory {
        let file = tokio::fs::File::open(policy_path())
            .await
            .expect("failed to open policy.wasm");
        PolicyFactory::load(file, data, entrypoints())
            .await
            .expect("failed to load policy")
    }

    /// A password registration for user `hello` with the given email address
    /// and no requester information.
    fn password_registration(email: &str) -> RegisterInput<'_> {
        RegisterInput {
            registration_method: RegistrationMethod::Password,
            username: "hello",
            email: Some(email),
            requester: Requester {
                ip_address: None,
                user_agent: None,
            },
        }
    }

    /// Dynamic data listing `hello` as a banned email address substring.
    fn ban_hello_data() -> mas_data_model::PolicyData {
        mas_data_model::PolicyData {
            id: Ulid::nil(),
            created_at: SystemTime::now().into(),
            data: serde_json::json!({
                "emails": {
                    "banned_addresses": {
                        "substrings": ["hello"]
                    }
                }
            }),
        }
    }

    #[tokio::test]
    async fn test_register() {
        let data = Data::new("example.com".to_owned()).with_rest(serde_json::json!({
            "allowed_domains": ["element.io", "*.element.io"],
            "banned_domains": ["staging.element.io"],
        }));
        let factory = load_factory(data).await;
        let mut policy = factory.instantiate().await.unwrap();

        // Not in `allowed_domains`
        let res = policy
            .evaluate_register(password_registration("hello@example.com"))
            .await
            .unwrap();
        assert!(!res.valid());

        // Covered by the `*.element.io` entry
        let res = policy
            .evaluate_register(password_registration("hello@foo.element.io"))
            .await
            .unwrap();
        assert!(res.valid());

        // Listed in `banned_domains`
        let res = policy
            .evaluate_register(password_registration("hello@staging.element.io"))
            .await
            .unwrap();
        assert!(!res.valid());
    }

    #[tokio::test]
    async fn test_dynamic_data() {
        let factory = load_factory(Data::new("example.com".to_owned())).await;

        let mut policy = factory.instantiate().await.unwrap();
        let res = policy
            .evaluate_register(password_registration("hello@example.com"))
            .await
            .unwrap();
        assert!(res.valid());

        // Update the policy data
        let updated = factory.set_dynamic_data(ban_hello_data()).await.unwrap();
        assert!(updated);

        // Setting the same version again is reported as a no-op
        let updated = factory.set_dynamic_data(ban_hello_data()).await.unwrap();
        assert!(!updated);

        let mut policy = factory.instantiate().await.unwrap();
        let res = policy
            .evaluate_register(password_registration("hello@example.com"))
            .await
            .unwrap();
        assert!(!res.valid());
    }

    #[tokio::test]
    async fn test_big_dynamic_data() {
        let factory = load_factory(Data::new("example.com".to_owned())).await;

        // That is around 1 MB of JSON data. Each element is a 5-digit string, so 8
        // characters including the quotes and a comma.
        let substrings: Vec<String> = (0..(1024 * 1024 / 8))
            .map(|i| format!("{:05}", i % 100_000))
            .collect();
        let data = serde_json::json!({
            "emails": { "banned_addresses": { "substrings": substrings } }
        });
        factory
            .set_dynamic_data(mas_data_model::PolicyData {
                id: Ulid::nil(),
                created_at: SystemTime::now().into(),
                data,
            })
            .await
            .unwrap();

        // Try instantiating the policy, make sure 5-digit numbers are banned from email
        // addresses
        let mut policy = factory.instantiate().await.unwrap();
        let res = policy
            .evaluate_register(password_registration("12345@example.com"))
            .await
            .unwrap();
        assert!(!res.valid());
    }

    #[test]
    fn test_merge() {
        use serde_json::json as j;

        // Merging objects
        let res = merge_data(j!({"hello": "world"}), j!({"foo": "bar"})).unwrap();
        assert_eq!(res, j!({"hello": "world", "foo": "bar"}));

        // Override a value of the same type
        let res = merge_data(j!({"hello": "world"}), j!({"hello": "john"})).unwrap();
        assert_eq!(res, j!({"hello": "john"}));

        let res = merge_data(j!({"hello": true}), j!({"hello": false})).unwrap();
        assert_eq!(res, j!({"hello": false}));

        let res = merge_data(j!({"hello": 0}), j!({"hello": 42})).unwrap();
        assert_eq!(res, j!({"hello": 42}));

        // Override a value of a different type
        merge_data(j!({"hello": "world"}), j!({"hello": 123}))
            .expect_err("Can't merge different types");

        // Containers of different kinds can't be merged either
        merge_data(j!({"hello": ["world"]}), j!({"hello": {"foo": "bar"}}))
            .expect_err("Can't merge an object into an array");

        // Merge arrays
        let res = merge_data(j!({"hello": ["world"]}), j!({"hello": ["john"]})).unwrap();
        assert_eq!(res, j!({"hello": ["world", "john"]}));

        // Null overrides a value
        let res = merge_data(j!({"hello": "world"}), j!({"hello": null})).unwrap();
        assert_eq!(res, j!({"hello": null}));

        // Null also overrides a whole object
        let res = merge_data(j!({"a": {"b": "c"}}), j!({"a": null})).unwrap();
        assert_eq!(res, j!({"a": null}));

        // Null gets overridden by a value
        let res = merge_data(j!({"hello": null}), j!({"hello": "world"})).unwrap();
        assert_eq!(res, j!({"hello": "world"}));

        // Objects get deeply merged
        let res = merge_data(j!({"a": {"b": {"c": "d"}}}), j!({"a": {"b": {"e": "f"}}})).unwrap();
        assert_eq!(res, j!({"a": {"b": {"c": "d", "e": "f"}}}));
    }
}