mas_matrix/
mock.rs

// Copyright 2024 New Vector Ltd.
// Copyright 2023, 2024 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only
// Please see LICENSE in the repository root for full details.
6
7use std::collections::{HashMap, HashSet};
8
9use anyhow::Context;
10use async_trait::async_trait;
11use tokio::sync::RwLock;
12
13use crate::{MatrixUser, ProvisionRequest};
14
/// In-memory record of a single user known to the mock homeserver.
struct MockUser {
    /// Subject the user was provisioned with; provisioning the same MXID
    /// again with a different subject is rejected.
    sub: String,

    /// Profile display name, if one is set.
    displayname: Option<String>,

    /// Profile avatar URL, if one is set.
    avatar_url: Option<String>,

    /// Email addresses from the last provisioning request that set them.
    emails: Option<Vec<String>>,

    /// IDs of the devices currently registered for this user.
    devices: HashSet<String>,

    /// Set once a cross-signing reset has been allowed for this user.
    cross_signing_reset_allowed: bool,

    /// Set when the user is deactivated, cleared on reactivation.
    deactivated: bool,
}
24
25/// A mock implementation of a [`HomeserverConnection`], which never fails and
26/// doesn't do anything.
27pub struct HomeserverConnection {
28    homeserver: String,
29    users: RwLock<HashMap<String, MockUser>>,
30    reserved_localparts: RwLock<HashSet<&'static str>>,
31}
32
33impl HomeserverConnection {
34    /// Create a new mock connection.
35    pub fn new<H>(homeserver: H) -> Self
36    where
37        H: Into<String>,
38    {
39        Self {
40            homeserver: homeserver.into(),
41            users: RwLock::new(HashMap::new()),
42            reserved_localparts: RwLock::new(HashSet::new()),
43        }
44    }
45
46    pub async fn reserve_localpart(&self, localpart: &'static str) {
47        self.reserved_localparts.write().await.insert(localpart);
48    }
49}
50
51#[async_trait]
52impl crate::HomeserverConnection for HomeserverConnection {
53    fn homeserver(&self) -> &str {
54        &self.homeserver
55    }
56
57    async fn query_user(&self, mxid: &str) -> Result<MatrixUser, anyhow::Error> {
58        let users = self.users.read().await;
59        let user = users.get(mxid).context("User not found")?;
60        Ok(MatrixUser {
61            displayname: user.displayname.clone(),
62            avatar_url: user.avatar_url.clone(),
63            deactivated: user.deactivated,
64        })
65    }
66
67    async fn provision_user(&self, request: &ProvisionRequest) -> Result<bool, anyhow::Error> {
68        let mut users = self.users.write().await;
69        let inserted = !users.contains_key(request.mxid());
70        let user = users.entry(request.mxid().to_owned()).or_insert(MockUser {
71            sub: request.sub().to_owned(),
72            avatar_url: None,
73            displayname: None,
74            devices: HashSet::new(),
75            emails: None,
76            cross_signing_reset_allowed: false,
77            deactivated: false,
78        });
79
80        anyhow::ensure!(
81            user.sub == request.sub(),
82            "User already provisioned with different sub"
83        );
84
85        request.on_emails(|emails| {
86            user.emails = emails.map(ToOwned::to_owned);
87        });
88
89        request.on_displayname(|displayname| {
90            user.displayname = displayname.map(ToOwned::to_owned);
91        });
92
93        request.on_avatar_url(|avatar_url| {
94            user.avatar_url = avatar_url.map(ToOwned::to_owned);
95        });
96
97        Ok(inserted)
98    }
99
100    async fn is_localpart_available(&self, localpart: &str) -> Result<bool, anyhow::Error> {
101        if self.reserved_localparts.read().await.contains(localpart) {
102            return Ok(false);
103        }
104
105        let mxid = self.mxid(localpart);
106        let users = self.users.read().await;
107        Ok(!users.contains_key(&mxid))
108    }
109
110    async fn create_device(
111        &self,
112        mxid: &str,
113        device_id: &str,
114        _initial_display_name: Option<&str>,
115    ) -> Result<(), anyhow::Error> {
116        let mut users = self.users.write().await;
117        let user = users.get_mut(mxid).context("User not found")?;
118        user.devices.insert(device_id.to_owned());
119        Ok(())
120    }
121
122    async fn update_device_display_name(
123        &self,
124        mxid: &str,
125        device_id: &str,
126        _display_name: &str,
127    ) -> Result<(), anyhow::Error> {
128        let mut users = self.users.write().await;
129        let user = users.get_mut(mxid).context("User not found")?;
130        user.devices.get(device_id).context("Device not found")?;
131        Ok(())
132    }
133
134    async fn delete_device(&self, mxid: &str, device_id: &str) -> Result<(), anyhow::Error> {
135        let mut users = self.users.write().await;
136        let user = users.get_mut(mxid).context("User not found")?;
137        user.devices.remove(device_id);
138        Ok(())
139    }
140
141    async fn sync_devices(
142        &self,
143        mxid: &str,
144        devices: HashSet<String>,
145    ) -> Result<(), anyhow::Error> {
146        let mut users = self.users.write().await;
147        let user = users.get_mut(mxid).context("User not found")?;
148        user.devices = devices;
149        Ok(())
150    }
151
152    async fn delete_user(&self, mxid: &str, erase: bool) -> Result<(), anyhow::Error> {
153        let mut users = self.users.write().await;
154        let user = users.get_mut(mxid).context("User not found")?;
155        user.devices.clear();
156        user.emails = None;
157        user.deactivated = true;
158        if erase {
159            user.avatar_url = None;
160            user.displayname = None;
161        }
162
163        Ok(())
164    }
165
166    async fn reactivate_user(&self, mxid: &str) -> Result<(), anyhow::Error> {
167        let mut users = self.users.write().await;
168        let user = users.get_mut(mxid).context("User not found")?;
169        user.deactivated = false;
170
171        Ok(())
172    }
173
174    async fn set_displayname(&self, mxid: &str, displayname: &str) -> Result<(), anyhow::Error> {
175        let mut users = self.users.write().await;
176        let user = users.get_mut(mxid).context("User not found")?;
177        user.displayname = Some(displayname.to_owned());
178        Ok(())
179    }
180
181    async fn unset_displayname(&self, mxid: &str) -> Result<(), anyhow::Error> {
182        let mut users = self.users.write().await;
183        let user = users.get_mut(mxid).context("User not found")?;
184        user.displayname = None;
185        Ok(())
186    }
187
188    async fn allow_cross_signing_reset(&self, mxid: &str) -> Result<(), anyhow::Error> {
189        let mut users = self.users.write().await;
190        let user = users.get_mut(mxid).context("User not found")?;
191        user.cross_signing_reset_allowed = true;
192        Ok(())
193    }
194}
195
#[cfg(test)]
mod tests {
    use super::*;
    use crate::HomeserverConnection as _;

    #[tokio::test]
    async fn test_mock_connection() {
        let conn = HomeserverConnection::new("example.org");

        let mxid = "@test:example.org";
        let device = "test";
        assert_eq!(conn.homeserver(), "example.org");
        assert_eq!(conn.mxid("test"), mxid);

        // Nothing is provisioned yet, so user-scoped operations fail
        assert!(conn.query_user(mxid).await.is_err());
        assert!(conn.create_device(mxid, device, None).await.is_err());
        assert!(conn.delete_device(mxid, device).await.is_err());

        let request = ProvisionRequest::new("@test:example.org", "test")
            .set_displayname("Test User".into())
            .set_avatar_url("mxc://example.org/1234567890".into())
            .set_emails(vec!["test@example.org".to_owned()]);

        let inserted = conn.provision_user(&request).await.unwrap();
        assert!(inserted);

        // Provisioning the same user again updates it instead of inserting it
        let inserted = conn.provision_user(&request).await.unwrap();
        assert!(!inserted);

        // Provisioning the same MXID with a different subject is rejected
        let conflicting = ProvisionRequest::new("@test:example.org", "other");
        assert!(conn.provision_user(&conflicting).await.is_err());

        let user = conn.query_user(mxid).await.unwrap();
        assert_eq!(user.displayname, Some("Test User".into()));
        assert_eq!(user.avatar_url, Some("mxc://example.org/1234567890".into()));
        assert!(!user.deactivated);

        // Set the displayname to a new value
        assert!(conn.set_displayname(mxid, "John").await.is_ok());

        let user = conn.query_user(mxid).await.unwrap();
        assert_eq!(user.displayname, Some("John".into()));

        // Unset the displayname
        assert!(conn.unset_displayname(mxid).await.is_ok());

        let user = conn.query_user(mxid).await.unwrap();
        assert_eq!(user.displayname, None);

        // Deleting a non-existent device should not fail
        assert!(conn.delete_device(mxid, device).await.is_ok());

        // Renaming a device that does not exist fails
        assert!(conn
            .update_device_display_name(mxid, device, "Phone")
            .await
            .is_err());

        // Create the device
        assert!(conn.create_device(mxid, device, None).await.is_ok());
        // Create the same device again
        assert!(conn.create_device(mxid, device, None).await.is_ok());
        // Once it exists, it can be renamed
        assert!(conn
            .update_device_display_name(mxid, device, "Phone")
            .await
            .is_ok());

        // XXX: there is no API to query devices yet in the trait
        // Delete the device
        assert!(conn.delete_device(mxid, device).await.is_ok());

        // Syncing devices replaces the whole device set
        let devices = HashSet::from(["a".to_owned(), "b".to_owned()]);
        assert!(conn.sync_devices(mxid, devices).await.is_ok());
        assert!(conn.update_device_display_name(mxid, "a", "A").await.is_ok());
        assert!(conn
            .update_device_display_name(mxid, device, "Phone")
            .await
            .is_err());

        // The user we just created should not be available
        assert!(!conn.is_localpart_available("test").await.unwrap());
        // But another user should be
        assert!(conn.is_localpart_available("alice").await.unwrap());

        // Reserve the localpart, it should not be available anymore
        conn.reserve_localpart("alice").await;
        assert!(!conn.is_localpart_available("alice").await.unwrap());

        // Cross-signing reset can only be allowed for known users
        assert!(conn.allow_cross_signing_reset(mxid).await.is_ok());
        assert!(conn
            .allow_cross_signing_reset("@nobody:example.org")
            .await
            .is_err());

        // Deactivating without erasure keeps the profile
        assert!(conn.delete_user(mxid, false).await.is_ok());
        let user = conn.query_user(mxid).await.unwrap();
        assert!(user.deactivated);
        assert_eq!(user.avatar_url, Some("mxc://example.org/1234567890".into()));

        // Reactivating clears the deactivated flag
        assert!(conn.reactivate_user(mxid).await.is_ok());
        let user = conn.query_user(mxid).await.unwrap();
        assert!(!user.deactivated);

        // Deactivating with erasure also clears the profile
        assert!(conn.delete_user(mxid, true).await.is_ok());
        let user = conn.query_user(mxid).await.unwrap();
        assert!(user.deactivated);
        assert_eq!(user.avatar_url, None);
        assert_eq!(user.displayname, None);

        // A deactivated user still occupies its localpart
        assert!(!conn.is_localpart_available("test").await.unwrap());
    }
}