1use std::collections::{HashMap, HashSet};
8
9use anyhow::Context;
10use async_trait::async_trait;
11use tokio::sync::RwLock;
12
13use crate::{MatrixUser, ProvisionRequest};
14
/// In-memory record of a single provisioned Matrix user.
struct MockUser {
    /// Subject the user was first provisioned with; `provision_user` rejects a
    /// later request for the same localpart that carries a different `sub`.
    sub: String,

    /// Set by `delete_user`, cleared by `reactivate_user`.
    deactivated: bool,

    /// Profile display name, if any.
    displayname: Option<String>,

    /// Profile avatar URL, if any.
    avatar_url: Option<String>,

    /// IDs of the devices currently registered for this user.
    devices: HashSet<String>,

    /// Email addresses as last applied by `provision_user`; `delete_user`
    /// resets this to `None`.
    emails: Option<Vec<String>>,

    /// Becomes `true` once `allow_cross_signing_reset` is called for the user.
    cross_signing_reset_allowed: bool,
}
24
25pub struct HomeserverConnection {
28 homeserver: String,
29 users: RwLock<HashMap<String, MockUser>>,
30 reserved_localparts: RwLock<HashSet<&'static str>>,
31}
32
33impl HomeserverConnection {
34 pub const VALID_BEARER_TOKEN: &str = "mock_homeserver_bearer_token";
37
38 pub fn new<H>(homeserver: H) -> Self
40 where
41 H: Into<String>,
42 {
43 Self {
44 homeserver: homeserver.into(),
45 users: RwLock::new(HashMap::new()),
46 reserved_localparts: RwLock::new(HashSet::new()),
47 }
48 }
49
50 pub async fn reserve_localpart(&self, localpart: &'static str) {
51 self.reserved_localparts.write().await.insert(localpart);
52 }
53}
54
55#[async_trait]
56impl crate::HomeserverConnection for HomeserverConnection {
57 fn homeserver(&self) -> &str {
58 &self.homeserver
59 }
60
61 async fn verify_token(&self, token: &str) -> Result<bool, anyhow::Error> {
62 Ok(token == Self::VALID_BEARER_TOKEN)
63 }
64
65 async fn query_user(&self, localpart: &str) -> Result<MatrixUser, anyhow::Error> {
66 let mxid = self.mxid(localpart);
67 let users = self.users.read().await;
68 let user = users.get(&mxid).context("User not found")?;
69 Ok(MatrixUser {
70 displayname: user.displayname.clone(),
71 avatar_url: user.avatar_url.clone(),
72 deactivated: user.deactivated,
73 })
74 }
75
76 async fn provision_user(&self, request: &ProvisionRequest) -> Result<bool, anyhow::Error> {
77 let mut users = self.users.write().await;
78 let mxid = self.mxid(request.localpart());
79 let inserted = !users.contains_key(&mxid);
80 let user = users.entry(mxid).or_insert(MockUser {
81 sub: request.sub().to_owned(),
82 avatar_url: None,
83 displayname: None,
84 devices: HashSet::new(),
85 emails: None,
86 cross_signing_reset_allowed: false,
87 deactivated: false,
88 });
89
90 anyhow::ensure!(
91 user.sub == request.sub(),
92 "User already provisioned with different sub"
93 );
94
95 request.on_emails(|emails| {
96 user.emails = emails.map(ToOwned::to_owned);
97 });
98
99 request.on_displayname(|displayname| {
100 user.displayname = displayname.map(ToOwned::to_owned);
101 });
102
103 request.on_avatar_url(|avatar_url| {
104 user.avatar_url = avatar_url.map(ToOwned::to_owned);
105 });
106
107 Ok(inserted)
108 }
109
110 async fn is_localpart_available(&self, localpart: &str) -> Result<bool, anyhow::Error> {
111 if self.reserved_localparts.read().await.contains(localpart) {
112 return Ok(false);
113 }
114
115 let mxid = self.mxid(localpart);
116 let users = self.users.read().await;
117 Ok(!users.contains_key(&mxid))
118 }
119
120 async fn upsert_device(
121 &self,
122 localpart: &str,
123 device_id: &str,
124 _initial_display_name: Option<&str>,
125 ) -> Result<(), anyhow::Error> {
126 let mxid = self.mxid(localpart);
127 let mut users = self.users.write().await;
128 let user = users.get_mut(&mxid).context("User not found")?;
129 user.devices.insert(device_id.to_owned());
130 Ok(())
131 }
132
133 async fn update_device_display_name(
134 &self,
135 localpart: &str,
136 device_id: &str,
137 _display_name: &str,
138 ) -> Result<(), anyhow::Error> {
139 let mxid = self.mxid(localpart);
140 let mut users = self.users.write().await;
141 let user = users.get_mut(&mxid).context("User not found")?;
142 user.devices.get(device_id).context("Device not found")?;
143 Ok(())
144 }
145
146 async fn delete_device(&self, localpart: &str, device_id: &str) -> Result<(), anyhow::Error> {
147 let mxid = self.mxid(localpart);
148 let mut users = self.users.write().await;
149 let user = users.get_mut(&mxid).context("User not found")?;
150 user.devices.remove(device_id);
151 Ok(())
152 }
153
154 async fn sync_devices(
155 &self,
156 localpart: &str,
157 devices: HashSet<String>,
158 ) -> Result<(), anyhow::Error> {
159 let mxid = self.mxid(localpart);
160 let mut users = self.users.write().await;
161 let user = users.get_mut(&mxid).context("User not found")?;
162 user.devices = devices;
163 Ok(())
164 }
165
166 async fn delete_user(&self, localpart: &str, erase: bool) -> Result<(), anyhow::Error> {
167 let mxid = self.mxid(localpart);
168 let mut users = self.users.write().await;
169 let user = users.get_mut(&mxid).context("User not found")?;
170 user.devices.clear();
171 user.emails = None;
172 user.deactivated = true;
173 if erase {
174 user.avatar_url = None;
175 user.displayname = None;
176 }
177
178 Ok(())
179 }
180
181 async fn reactivate_user(&self, localpart: &str) -> Result<(), anyhow::Error> {
182 let mxid = self.mxid(localpart);
183 let mut users = self.users.write().await;
184 let user = users.get_mut(&mxid).context("User not found")?;
185 user.deactivated = false;
186
187 Ok(())
188 }
189
190 async fn set_displayname(
191 &self,
192 localpart: &str,
193 displayname: &str,
194 ) -> Result<(), anyhow::Error> {
195 let mxid = self.mxid(localpart);
196 let mut users = self.users.write().await;
197 let user = users.get_mut(&mxid).context("User not found")?;
198 user.displayname = Some(displayname.to_owned());
199 Ok(())
200 }
201
202 async fn unset_displayname(&self, localpart: &str) -> Result<(), anyhow::Error> {
203 let mxid = self.mxid(localpart);
204 let mut users = self.users.write().await;
205 let user = users.get_mut(&mxid).context("User not found")?;
206 user.displayname = None;
207 Ok(())
208 }
209
210 async fn allow_cross_signing_reset(&self, localpart: &str) -> Result<(), anyhow::Error> {
211 let mxid = self.mxid(localpart);
212 let mut users = self.users.write().await;
213 let user = users.get_mut(&mxid).context("User not found")?;
214 user.cross_signing_reset_allowed = true;
215 Ok(())
216 }
217}
218
#[cfg(test)]
mod tests {
    use super::*;
    use crate::HomeserverConnection as _;

    /// End-to-end walk through every operation of the mock connection.
    #[tokio::test]
    async fn test_mock_connection() {
        let conn = HomeserverConnection::new("example.org");

        let mxid = "@test:example.org";
        let device = "test";
        assert_eq!(conn.homeserver(), "example.org");
        assert_eq!(conn.mxid("test"), mxid);

        assert!(conn
            .verify_token(HomeserverConnection::VALID_BEARER_TOKEN)
            .await
            .unwrap());
        assert!(!conn.verify_token("wrong_token").await.unwrap());

        // Every per-user operation fails before the user is provisioned.
        assert!(conn.query_user("test").await.is_err());
        assert!(conn.upsert_device("test", device, None).await.is_err());
        assert!(conn.delete_device("test", device).await.is_err());
        assert!(conn
            .update_device_display_name("test", device, "Phone")
            .await
            .is_err());
        assert!(conn.delete_user("test", false).await.is_err());
        assert!(conn.reactivate_user("test").await.is_err());
        assert!(conn.allow_cross_signing_reset("test").await.is_err());

        let request = ProvisionRequest::new("test", "test")
            .set_displayname("Test User".into())
            .set_avatar_url("mxc://example.org/1234567890".into())
            .set_emails(vec!["test@example.org".to_owned()]);

        let inserted = conn.provision_user(&request).await.unwrap();
        assert!(inserted);

        // Provisioning again with the same sub updates the existing user.
        let inserted = conn.provision_user(&request).await.unwrap();
        assert!(!inserted);

        let user = conn.query_user("test").await.unwrap();
        assert_eq!(user.displayname, Some("Test User".into()));
        assert_eq!(user.avatar_url, Some("mxc://example.org/1234567890".into()));
        assert!(!user.deactivated);

        assert!(conn.set_displayname("test", "John").await.is_ok());

        let user = conn.query_user("test").await.unwrap();
        assert_eq!(user.displayname, Some("John".into()));

        assert!(conn.unset_displayname("test").await.is_ok());

        let user = conn.query_user("test").await.unwrap();
        assert_eq!(user.displayname, None);

        // Deleting a device that was never registered is not an error.
        assert!(conn.delete_device("test", device).await.is_ok());

        assert!(conn.upsert_device("test", device, None).await.is_ok());
        // Upserting the same device twice is allowed.
        assert!(conn.upsert_device("test", device, None).await.is_ok());
        assert!(conn
            .update_device_display_name("test", device, "Phone")
            .await
            .is_ok());

        assert!(conn.delete_device("test", device).await.is_ok());
        assert!(conn
            .update_device_display_name("test", device, "Phone")
            .await
            .is_err());

        // sync_devices replaces the whole device set.
        let synced = HashSet::from(["a".to_owned(), "b".to_owned()]);
        assert!(conn.sync_devices("test", synced).await.is_ok());
        assert!(conn.update_device_display_name("test", "a", "A").await.is_ok());
        assert!(conn.update_device_display_name("test", "b", "B").await.is_ok());
        let synced = HashSet::from(["b".to_owned()]);
        assert!(conn.sync_devices("test", synced).await.is_ok());
        assert!(conn.update_device_display_name("test", "a", "A").await.is_err());

        assert!(conn.allow_cross_signing_reset("test").await.is_ok());

        // Deactivation without erasure keeps the profile but drops devices.
        assert!(conn.set_displayname("test", "Jane").await.is_ok());
        assert!(conn.delete_user("test", false).await.is_ok());
        let user = conn.query_user("test").await.unwrap();
        assert!(user.deactivated);
        assert_eq!(user.displayname, Some("Jane".into()));
        assert_eq!(user.avatar_url, Some("mxc://example.org/1234567890".into()));
        assert!(conn.update_device_display_name("test", "b", "B").await.is_err());

        assert!(conn.reactivate_user("test").await.is_ok());
        let user = conn.query_user("test").await.unwrap();
        assert!(!user.deactivated);

        // Deactivation with erasure also clears the profile.
        assert!(conn.delete_user("test", true).await.is_ok());
        let user = conn.query_user("test").await.unwrap();
        assert!(user.deactivated);
        assert_eq!(user.displayname, None);
        assert_eq!(user.avatar_url, None);

        // A deactivated user still occupies its localpart.
        assert!(!conn.is_localpart_available("test").await.unwrap());
        assert!(conn.is_localpart_available("alice").await.unwrap());

        conn.reserve_localpart("alice").await;
        assert!(!conn.is_localpart_available("alice").await.unwrap());
    }
}