mas_matrix/
mock.rs

1// Copyright 2024, 2025 New Vector Ltd.
2// Copyright 2023, 2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
5// Please see LICENSE files in the repository root for full details.
6
7use std::collections::{HashMap, HashSet};
8
9use anyhow::Context;
10use async_trait::async_trait;
11use tokio::sync::RwLock;
12
13use crate::{MatrixUser, ProvisionRequest};
14
/// In-memory record of a user provisioned on the mock homeserver.
struct MockUser {
    // --- identity & account state ---
    /// The `sub` the user was provisioned with; re-provisioning must match it.
    sub: String,
    /// Set when the user is deleted, cleared when the user is reactivated.
    deactivated: bool,

    // --- profile ---
    /// The user's display name, if any.
    displayname: Option<String>,
    /// The user's avatar URL, if any.
    avatar_url: Option<String>,
    /// The user's email addresses; `None` when unset.
    emails: Option<Vec<String>>,

    // --- devices & cross-signing ---
    /// IDs of the user's devices.
    devices: HashSet<String>,
    /// Whether a cross-signing reset has been allowed for this user.
    cross_signing_reset_allowed: bool,
}
24
25/// A mock implementation of a [`HomeserverConnection`], which never fails and
26/// doesn't do anything.
27pub struct HomeserverConnection {
28    homeserver: String,
29    users: RwLock<HashMap<String, MockUser>>,
30    reserved_localparts: RwLock<HashSet<&'static str>>,
31}
32
33impl HomeserverConnection {
34    /// A valid bearer token that will be accepted by
35    /// [`crate::HomeserverConnection::verify_token`].
36    pub const VALID_BEARER_TOKEN: &str = "mock_homeserver_bearer_token";
37
38    /// Create a new mock connection.
39    pub fn new<H>(homeserver: H) -> Self
40    where
41        H: Into<String>,
42    {
43        Self {
44            homeserver: homeserver.into(),
45            users: RwLock::new(HashMap::new()),
46            reserved_localparts: RwLock::new(HashSet::new()),
47        }
48    }
49
50    pub async fn reserve_localpart(&self, localpart: &'static str) {
51        self.reserved_localparts.write().await.insert(localpart);
52    }
53}
54
55#[async_trait]
56impl crate::HomeserverConnection for HomeserverConnection {
57    fn homeserver(&self) -> &str {
58        &self.homeserver
59    }
60
61    async fn verify_token(&self, token: &str) -> Result<bool, anyhow::Error> {
62        Ok(token == Self::VALID_BEARER_TOKEN)
63    }
64
65    async fn query_user(&self, localpart: &str) -> Result<MatrixUser, anyhow::Error> {
66        let mxid = self.mxid(localpart);
67        let users = self.users.read().await;
68        let user = users.get(&mxid).context("User not found")?;
69        Ok(MatrixUser {
70            displayname: user.displayname.clone(),
71            avatar_url: user.avatar_url.clone(),
72            deactivated: user.deactivated,
73        })
74    }
75
76    async fn provision_user(&self, request: &ProvisionRequest) -> Result<bool, anyhow::Error> {
77        let mut users = self.users.write().await;
78        let mxid = self.mxid(request.localpart());
79        let inserted = !users.contains_key(&mxid);
80        let user = users.entry(mxid).or_insert(MockUser {
81            sub: request.sub().to_owned(),
82            avatar_url: None,
83            displayname: None,
84            devices: HashSet::new(),
85            emails: None,
86            cross_signing_reset_allowed: false,
87            deactivated: false,
88        });
89
90        anyhow::ensure!(
91            user.sub == request.sub(),
92            "User already provisioned with different sub"
93        );
94
95        request.on_emails(|emails| {
96            user.emails = emails.map(ToOwned::to_owned);
97        });
98
99        request.on_displayname(|displayname| {
100            user.displayname = displayname.map(ToOwned::to_owned);
101        });
102
103        request.on_avatar_url(|avatar_url| {
104            user.avatar_url = avatar_url.map(ToOwned::to_owned);
105        });
106
107        Ok(inserted)
108    }
109
110    async fn is_localpart_available(&self, localpart: &str) -> Result<bool, anyhow::Error> {
111        if self.reserved_localparts.read().await.contains(localpart) {
112            return Ok(false);
113        }
114
115        let mxid = self.mxid(localpart);
116        let users = self.users.read().await;
117        Ok(!users.contains_key(&mxid))
118    }
119
120    async fn upsert_device(
121        &self,
122        localpart: &str,
123        device_id: &str,
124        _initial_display_name: Option<&str>,
125    ) -> Result<(), anyhow::Error> {
126        let mxid = self.mxid(localpart);
127        let mut users = self.users.write().await;
128        let user = users.get_mut(&mxid).context("User not found")?;
129        user.devices.insert(device_id.to_owned());
130        Ok(())
131    }
132
133    async fn update_device_display_name(
134        &self,
135        localpart: &str,
136        device_id: &str,
137        _display_name: &str,
138    ) -> Result<(), anyhow::Error> {
139        let mxid = self.mxid(localpart);
140        let mut users = self.users.write().await;
141        let user = users.get_mut(&mxid).context("User not found")?;
142        user.devices.get(device_id).context("Device not found")?;
143        Ok(())
144    }
145
146    async fn delete_device(&self, localpart: &str, device_id: &str) -> Result<(), anyhow::Error> {
147        let mxid = self.mxid(localpart);
148        let mut users = self.users.write().await;
149        let user = users.get_mut(&mxid).context("User not found")?;
150        user.devices.remove(device_id);
151        Ok(())
152    }
153
154    async fn sync_devices(
155        &self,
156        localpart: &str,
157        devices: HashSet<String>,
158    ) -> Result<(), anyhow::Error> {
159        let mxid = self.mxid(localpart);
160        let mut users = self.users.write().await;
161        let user = users.get_mut(&mxid).context("User not found")?;
162        user.devices = devices;
163        Ok(())
164    }
165
166    async fn delete_user(&self, localpart: &str, erase: bool) -> Result<(), anyhow::Error> {
167        let mxid = self.mxid(localpart);
168        let mut users = self.users.write().await;
169        let user = users.get_mut(&mxid).context("User not found")?;
170        user.devices.clear();
171        user.emails = None;
172        user.deactivated = true;
173        if erase {
174            user.avatar_url = None;
175            user.displayname = None;
176        }
177
178        Ok(())
179    }
180
181    async fn reactivate_user(&self, localpart: &str) -> Result<(), anyhow::Error> {
182        let mxid = self.mxid(localpart);
183        let mut users = self.users.write().await;
184        let user = users.get_mut(&mxid).context("User not found")?;
185        user.deactivated = false;
186
187        Ok(())
188    }
189
190    async fn set_displayname(
191        &self,
192        localpart: &str,
193        displayname: &str,
194    ) -> Result<(), anyhow::Error> {
195        let mxid = self.mxid(localpart);
196        let mut users = self.users.write().await;
197        let user = users.get_mut(&mxid).context("User not found")?;
198        user.displayname = Some(displayname.to_owned());
199        Ok(())
200    }
201
202    async fn unset_displayname(&self, localpart: &str) -> Result<(), anyhow::Error> {
203        let mxid = self.mxid(localpart);
204        let mut users = self.users.write().await;
205        let user = users.get_mut(&mxid).context("User not found")?;
206        user.displayname = None;
207        Ok(())
208    }
209
210    async fn allow_cross_signing_reset(&self, localpart: &str) -> Result<(), anyhow::Error> {
211        let mxid = self.mxid(localpart);
212        let mut users = self.users.write().await;
213        let user = users.get_mut(&mxid).context("User not found")?;
214        user.cross_signing_reset_allowed = true;
215        Ok(())
216    }
217}
218
#[cfg(test)]
mod tests {
    use super::*;
    use crate::HomeserverConnection as _;

    #[tokio::test]
    async fn test_mock_connection() {
        let conn = HomeserverConnection::new("example.org");

        let mxid = "@test:example.org";
        let device = "test";
        assert_eq!(conn.homeserver(), "example.org");
        assert_eq!(conn.mxid("test"), mxid);

        assert!(conn.query_user("test").await.is_err());
        assert!(conn.upsert_device("test", device, None).await.is_err());
        assert!(conn.delete_device("test", device).await.is_err());

        let request = ProvisionRequest::new("test", "test")
            .set_displayname("Test User".into())
            .set_avatar_url("mxc://example.org/1234567890".into())
            .set_emails(vec!["test@example.org".to_owned()]);

        let inserted = conn.provision_user(&request).await.unwrap();
        assert!(inserted);

        let user = conn.query_user("test").await.unwrap();
        assert_eq!(user.displayname, Some("Test User".into()));
        assert_eq!(user.avatar_url, Some("mxc://example.org/1234567890".into()));

        // Set the displayname again
        assert!(conn.set_displayname("test", "John").await.is_ok());

        let user = conn.query_user("test").await.unwrap();
        assert_eq!(user.displayname, Some("John".into()));

        // Unset the displayname
        assert!(conn.unset_displayname("test").await.is_ok());

        let user = conn.query_user("test").await.unwrap();
        assert_eq!(user.displayname, None);

        // Deleting a non-existent device should not fail
        assert!(conn.delete_device("test", device).await.is_ok());

        // Create the device
        assert!(conn.upsert_device("test", device, None).await.is_ok());
        // Create the same device again
        assert!(conn.upsert_device("test", device, None).await.is_ok());

        // XXX: there is no API to query devices yet in the trait
        // Delete the device
        assert!(conn.delete_device("test", device).await.is_ok());

        // The user we just created should not be available
        assert!(!conn.is_localpart_available("test").await.unwrap());
        // But another user should be
        assert!(conn.is_localpart_available("alice").await.unwrap());

        // Reserve the localpart, it should not be available anymore
        conn.reserve_localpart("alice").await;
        assert!(!conn.is_localpart_available("alice").await.unwrap());
    }

    #[tokio::test]
    async fn test_provision_with_different_sub_fails() {
        let conn = HomeserverConnection::new("example.org");

        let request = ProvisionRequest::new("test", "sub1");
        assert!(conn.provision_user(&request).await.unwrap());

        // Provisioning again with the same sub updates rather than inserts
        assert!(!conn.provision_user(&request).await.unwrap());

        // The same localpart with a different sub is rejected
        let other = ProvisionRequest::new("test", "sub2");
        assert!(conn.provision_user(&other).await.is_err());
    }

    #[tokio::test]
    async fn test_delete_and_reactivate_user() {
        let conn = HomeserverConnection::new("example.org");

        // Unknown users can be neither deleted nor reactivated
        assert!(conn.delete_user("test", false).await.is_err());
        assert!(conn.reactivate_user("test").await.is_err());

        let request = ProvisionRequest::new("test", "test")
            .set_displayname("Test User".into())
            .set_avatar_url("mxc://example.org/1234567890".into());
        conn.provision_user(&request).await.unwrap();

        // Deactivating without erasing keeps the profile
        conn.delete_user("test", false).await.unwrap();
        let user = conn.query_user("test").await.unwrap();
        assert!(user.deactivated);
        assert_eq!(user.displayname, Some("Test User".into()));

        // A deactivated user still occupies the localpart
        assert!(!conn.is_localpart_available("test").await.unwrap());

        conn.reactivate_user("test").await.unwrap();
        let user = conn.query_user("test").await.unwrap();
        assert!(!user.deactivated);

        // Erasing also clears the profile
        conn.delete_user("test", true).await.unwrap();
        let user = conn.query_user("test").await.unwrap();
        assert!(user.deactivated);
        assert_eq!(user.displayname, None);
        assert_eq!(user.avatar_url, None);
    }

    #[tokio::test]
    async fn test_device_display_name_and_sync() {
        let conn = HomeserverConnection::new("example.org");

        // Unknown user
        assert!(conn
            .update_device_display_name("test", "dev", "Name")
            .await
            .is_err());

        conn.provision_user(&ProvisionRequest::new("test", "test"))
            .await
            .unwrap();

        // Unknown device
        assert!(conn
            .update_device_display_name("test", "dev", "Name")
            .await
            .is_err());

        conn.upsert_device("test", "dev", None).await.unwrap();
        assert!(conn
            .update_device_display_name("test", "dev", "Name")
            .await
            .is_ok());

        // sync_devices replaces the whole device set
        conn.sync_devices("test", HashSet::from(["other".to_owned()]))
            .await
            .unwrap();
        assert!(conn
            .update_device_display_name("test", "dev", "Name")
            .await
            .is_err());
        assert!(conn
            .update_device_display_name("test", "other", "Name")
            .await
            .is_ok());
    }
}
282}