1use std::collections::{HashMap, HashSet};
8
9use anyhow::Context;
10use async_trait::async_trait;
11use tokio::sync::RwLock;
12
13use crate::{MatrixUser, ProvisionRequest};
14
/// In-memory record of a single user known to the mock homeserver.
struct MockUser {
    // --- identity ---
    /// Subject the user was provisioned with; re-provisioning the same MXID
    /// with a different subject is rejected.
    sub: String,

    // --- profile ---
    /// Display name, if one has been set.
    displayname: Option<String>,
    /// Avatar URL, if one has been set.
    avatar_url: Option<String>,
    /// Email addresses attached to the user, if provided.
    emails: Option<Vec<String>>,

    // --- account state ---
    /// IDs of the devices currently registered for this user.
    devices: HashSet<String>,
    /// Whether a cross-signing reset has been allowed for this user.
    cross_signing_reset_allowed: bool,
    /// Whether the account is currently deactivated.
    deactivated: bool,
}
24
25pub struct HomeserverConnection {
28 homeserver: String,
29 users: RwLock<HashMap<String, MockUser>>,
30 reserved_localparts: RwLock<HashSet<&'static str>>,
31}
32
33impl HomeserverConnection {
34 pub fn new<H>(homeserver: H) -> Self
36 where
37 H: Into<String>,
38 {
39 Self {
40 homeserver: homeserver.into(),
41 users: RwLock::new(HashMap::new()),
42 reserved_localparts: RwLock::new(HashSet::new()),
43 }
44 }
45
46 pub async fn reserve_localpart(&self, localpart: &'static str) {
47 self.reserved_localparts.write().await.insert(localpart);
48 }
49}
50
51#[async_trait]
52impl crate::HomeserverConnection for HomeserverConnection {
53 fn homeserver(&self) -> &str {
54 &self.homeserver
55 }
56
57 async fn query_user(&self, mxid: &str) -> Result<MatrixUser, anyhow::Error> {
58 let users = self.users.read().await;
59 let user = users.get(mxid).context("User not found")?;
60 Ok(MatrixUser {
61 displayname: user.displayname.clone(),
62 avatar_url: user.avatar_url.clone(),
63 deactivated: user.deactivated,
64 })
65 }
66
67 async fn provision_user(&self, request: &ProvisionRequest) -> Result<bool, anyhow::Error> {
68 let mut users = self.users.write().await;
69 let inserted = !users.contains_key(request.mxid());
70 let user = users.entry(request.mxid().to_owned()).or_insert(MockUser {
71 sub: request.sub().to_owned(),
72 avatar_url: None,
73 displayname: None,
74 devices: HashSet::new(),
75 emails: None,
76 cross_signing_reset_allowed: false,
77 deactivated: false,
78 });
79
80 anyhow::ensure!(
81 user.sub == request.sub(),
82 "User already provisioned with different sub"
83 );
84
85 request.on_emails(|emails| {
86 user.emails = emails.map(ToOwned::to_owned);
87 });
88
89 request.on_displayname(|displayname| {
90 user.displayname = displayname.map(ToOwned::to_owned);
91 });
92
93 request.on_avatar_url(|avatar_url| {
94 user.avatar_url = avatar_url.map(ToOwned::to_owned);
95 });
96
97 Ok(inserted)
98 }
99
100 async fn is_localpart_available(&self, localpart: &str) -> Result<bool, anyhow::Error> {
101 if self.reserved_localparts.read().await.contains(localpart) {
102 return Ok(false);
103 }
104
105 let mxid = self.mxid(localpart);
106 let users = self.users.read().await;
107 Ok(!users.contains_key(&mxid))
108 }
109
110 async fn create_device(
111 &self,
112 mxid: &str,
113 device_id: &str,
114 _initial_display_name: Option<&str>,
115 ) -> Result<(), anyhow::Error> {
116 let mut users = self.users.write().await;
117 let user = users.get_mut(mxid).context("User not found")?;
118 user.devices.insert(device_id.to_owned());
119 Ok(())
120 }
121
122 async fn update_device_display_name(
123 &self,
124 mxid: &str,
125 device_id: &str,
126 _display_name: &str,
127 ) -> Result<(), anyhow::Error> {
128 let mut users = self.users.write().await;
129 let user = users.get_mut(mxid).context("User not found")?;
130 user.devices.get(device_id).context("Device not found")?;
131 Ok(())
132 }
133
134 async fn delete_device(&self, mxid: &str, device_id: &str) -> Result<(), anyhow::Error> {
135 let mut users = self.users.write().await;
136 let user = users.get_mut(mxid).context("User not found")?;
137 user.devices.remove(device_id);
138 Ok(())
139 }
140
141 async fn sync_devices(
142 &self,
143 mxid: &str,
144 devices: HashSet<String>,
145 ) -> Result<(), anyhow::Error> {
146 let mut users = self.users.write().await;
147 let user = users.get_mut(mxid).context("User not found")?;
148 user.devices = devices;
149 Ok(())
150 }
151
152 async fn delete_user(&self, mxid: &str, erase: bool) -> Result<(), anyhow::Error> {
153 let mut users = self.users.write().await;
154 let user = users.get_mut(mxid).context("User not found")?;
155 user.devices.clear();
156 user.emails = None;
157 user.deactivated = true;
158 if erase {
159 user.avatar_url = None;
160 user.displayname = None;
161 }
162
163 Ok(())
164 }
165
166 async fn reactivate_user(&self, mxid: &str) -> Result<(), anyhow::Error> {
167 let mut users = self.users.write().await;
168 let user = users.get_mut(mxid).context("User not found")?;
169 user.deactivated = false;
170
171 Ok(())
172 }
173
174 async fn set_displayname(&self, mxid: &str, displayname: &str) -> Result<(), anyhow::Error> {
175 let mut users = self.users.write().await;
176 let user = users.get_mut(mxid).context("User not found")?;
177 user.displayname = Some(displayname.to_owned());
178 Ok(())
179 }
180
181 async fn unset_displayname(&self, mxid: &str) -> Result<(), anyhow::Error> {
182 let mut users = self.users.write().await;
183 let user = users.get_mut(mxid).context("User not found")?;
184 user.displayname = None;
185 Ok(())
186 }
187
188 async fn allow_cross_signing_reset(&self, mxid: &str) -> Result<(), anyhow::Error> {
189 let mut users = self.users.write().await;
190 let user = users.get_mut(mxid).context("User not found")?;
191 user.cross_signing_reset_allowed = true;
192 Ok(())
193 }
194}
195
#[cfg(test)]
mod tests {
    use super::*;
    use crate::HomeserverConnection as _;

    #[tokio::test]
    async fn test_mock_connection() {
        let conn = HomeserverConnection::new("example.org");

        let mxid = "@test:example.org";
        let device = "test";
        assert_eq!(conn.homeserver(), "example.org");
        assert_eq!(conn.mxid("test"), mxid);

        // Nothing works before the user is provisioned.
        assert!(conn.query_user(mxid).await.is_err());
        assert!(conn.create_device(mxid, device, None).await.is_err());
        assert!(conn.delete_device(mxid, device).await.is_err());

        let request = ProvisionRequest::new("@test:example.org", "test")
            .set_displayname("Test User".into())
            .set_avatar_url("mxc://example.org/1234567890".into())
            .set_emails(vec!["test@example.org".to_owned()]);

        let inserted = conn.provision_user(&request).await.unwrap();
        assert!(inserted);

        let user = conn.query_user(mxid).await.unwrap();
        assert_eq!(user.displayname, Some("Test User".into()));
        assert_eq!(user.avatar_url, Some("mxc://example.org/1234567890".into()));

        assert!(conn.set_displayname(mxid, "John").await.is_ok());

        let user = conn.query_user(mxid).await.unwrap();
        assert_eq!(user.displayname, Some("John".into()));

        assert!(conn.unset_displayname(mxid).await.is_ok());

        let user = conn.query_user(mxid).await.unwrap();
        assert_eq!(user.displayname, None);

        // Deleting an unknown device is not an error.
        assert!(conn.delete_device(mxid, device).await.is_ok());

        // Creating the same device twice is idempotent.
        assert!(conn.create_device(mxid, device, None).await.is_ok());
        assert!(conn.create_device(mxid, device, None).await.is_ok());

        assert!(conn.delete_device(mxid, device).await.is_ok());

        assert!(!conn.is_localpart_available("test").await.unwrap());
        assert!(conn.is_localpart_available("alice").await.unwrap());

        conn.reserve_localpart("alice").await;
        assert!(!conn.is_localpart_available("alice").await.unwrap());

        // Renaming a device requires that the device exists.
        assert!(conn
            .update_device_display_name(mxid, device, "Phone")
            .await
            .is_err());
        assert!(conn.create_device(mxid, device, None).await.is_ok());
        assert!(conn
            .update_device_display_name(mxid, device, "Phone")
            .await
            .is_ok());

        // Re-provisioning with the same sub is an update, not an insertion,
        // and re-applies the requested profile fields.
        let inserted = conn.provision_user(&request).await.unwrap();
        assert!(!inserted);
        let user = conn.query_user(mxid).await.unwrap();
        assert_eq!(user.displayname, Some("Test User".into()));

        // The same MXID with a different sub is rejected.
        let other = ProvisionRequest::new("@test:example.org", "other");
        assert!(conn.provision_user(&other).await.is_err());

        // Erasing deactivation wipes the profile and the devices.
        assert!(conn.delete_user(mxid, true).await.is_ok());
        let user = conn.query_user(mxid).await.unwrap();
        assert!(user.deactivated);
        assert_eq!(user.displayname, None);
        assert_eq!(user.avatar_url, None);
        assert!(conn
            .update_device_display_name(mxid, device, "Phone")
            .await
            .is_err());

        // Reactivation only flips the flag back.
        assert!(conn.reactivate_user(mxid).await.is_ok());
        let user = conn.query_user(mxid).await.unwrap();
        assert!(!user.deactivated);
        assert_eq!(user.displayname, None);

        // Operations on unknown users fail.
        assert!(conn.delete_user("@nobody:example.org", false).await.is_err());
        assert!(conn.reactivate_user("@nobody:example.org").await.is_err());
    }
}